You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

208 lines
6.5 KiB

  1. from __future__ import unicode_literals
  2. import datetime
  3. import decimal
  4. import hashlib
  5. import logging
  6. from time import time
  7. from django.conf import settings
  8. from django.utils.encoding import force_bytes
  9. from django.utils.timezone import utc
  10. logger = logging.getLogger('django.db.backends')
  11. class CursorWrapper(object):
  12. def __init__(self, cursor, db):
  13. self.cursor = cursor
  14. self.db = db
  15. WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
  16. def __getattr__(self, attr):
  17. cursor_attr = getattr(self.cursor, attr)
  18. if attr in CursorWrapper.WRAP_ERROR_ATTRS:
  19. return self.db.wrap_database_errors(cursor_attr)
  20. else:
  21. return cursor_attr
  22. def __iter__(self):
  23. with self.db.wrap_database_errors:
  24. for item in self.cursor:
  25. yield item
  26. def __enter__(self):
  27. return self
  28. def __exit__(self, type, value, traceback):
  29. # Ticket #17671 - Close instead of passing thru to avoid backend
  30. # specific behavior. Catch errors liberally because errors in cleanup
  31. # code aren't useful.
  32. try:
  33. self.close()
  34. except self.db.Database.Error:
  35. pass
  36. # The following methods cannot be implemented in __getattr__, because the
  37. # code must run when the method is invoked, not just when it is accessed.
  38. def callproc(self, procname, params=None):
  39. self.db.validate_no_broken_transaction()
  40. with self.db.wrap_database_errors:
  41. if params is None:
  42. return self.cursor.callproc(procname)
  43. else:
  44. return self.cursor.callproc(procname, params)
  45. def execute(self, sql, params=None):
  46. self.db.validate_no_broken_transaction()
  47. with self.db.wrap_database_errors:
  48. if params is None:
  49. return self.cursor.execute(sql)
  50. else:
  51. return self.cursor.execute(sql, params)
  52. def executemany(self, sql, param_list):
  53. self.db.validate_no_broken_transaction()
  54. with self.db.wrap_database_errors:
  55. return self.cursor.executemany(sql, param_list)
  56. class CursorDebugWrapper(CursorWrapper):
  57. # XXX callproc isn't instrumented at this time.
  58. def execute(self, sql, params=None):
  59. start = time()
  60. try:
  61. return super(CursorDebugWrapper, self).execute(sql, params)
  62. finally:
  63. stop = time()
  64. duration = stop - start
  65. sql = self.db.ops.last_executed_query(self.cursor, sql, params)
  66. self.db.queries_log.append({
  67. 'sql': sql,
  68. 'time': "%.3f" % duration,
  69. })
  70. logger.debug('(%.3f) %s; args=%s' % (duration, sql, params),
  71. extra={'duration': duration, 'sql': sql, 'params': params}
  72. )
  73. def executemany(self, sql, param_list):
  74. start = time()
  75. try:
  76. return super(CursorDebugWrapper, self).executemany(sql, param_list)
  77. finally:
  78. stop = time()
  79. duration = stop - start
  80. try:
  81. times = len(param_list)
  82. except TypeError: # param_list could be an iterator
  83. times = '?'
  84. self.db.queries_log.append({
  85. 'sql': '%s times: %s' % (times, sql),
  86. 'time': "%.3f" % duration,
  87. })
  88. logger.debug('(%.3f) %s; args=%s' % (duration, sql, param_list),
  89. extra={'duration': duration, 'sql': sql, 'params': param_list}
  90. )
  91. ###############################################
  92. # Converters from database (string) to Python #
  93. ###############################################
  94. def typecast_date(s):
  95. return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null
  96. def typecast_time(s): # does NOT store time zone information
  97. if not s:
  98. return None
  99. hour, minutes, seconds = s.split(':')
  100. if '.' in seconds: # check whether seconds have a fractional part
  101. seconds, microseconds = seconds.split('.')
  102. else:
  103. microseconds = '0'
  104. return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))
  105. def typecast_timestamp(s): # does NOT store time zone information
  106. # "2005-07-29 15:48:00.590358-05"
  107. # "2005-07-29 09:56:00-05"
  108. if not s:
  109. return None
  110. if ' ' not in s:
  111. return typecast_date(s)
  112. d, t = s.split()
  113. # Extract timezone information, if it exists. Currently we just throw
  114. # it away, but in the future we may make use of it.
  115. if '-' in t:
  116. t, tz = t.split('-', 1)
  117. tz = '-' + tz
  118. elif '+' in t:
  119. t, tz = t.split('+', 1)
  120. tz = '+' + tz
  121. else:
  122. tz = ''
  123. dates = d.split('-')
  124. times = t.split(':')
  125. seconds = times[2]
  126. if '.' in seconds: # check whether seconds have a fractional part
  127. seconds, microseconds = seconds.split('.')
  128. else:
  129. microseconds = '0'
  130. tzinfo = utc if settings.USE_TZ else None
  131. return datetime.datetime(int(dates[0]), int(dates[1]), int(dates[2]),
  132. int(times[0]), int(times[1]), int(seconds),
  133. int((microseconds + '000000')[:6]), tzinfo)
  134. def typecast_decimal(s):
  135. if s is None or s == '':
  136. return None
  137. return decimal.Decimal(s)
  138. ###############################################
  139. # Converters from Python to database (string) #
  140. ###############################################
  141. def rev_typecast_decimal(d):
  142. if d is None:
  143. return None
  144. return str(d)
  145. def truncate_name(name, length=None, hash_len=4):
  146. """Shortens a string to a repeatable mangled version with the given length.
  147. """
  148. if length is None or len(name) <= length:
  149. return name
  150. hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
  151. return '%s%s' % (name[:length - hash_len], hsh)
  152. def format_number(value, max_digits, decimal_places):
  153. """
  154. Formats a number into a string with the requisite number of digits and
  155. decimal places.
  156. """
  157. if value is None:
  158. return None
  159. if isinstance(value, decimal.Decimal):
  160. context = decimal.getcontext().copy()
  161. if max_digits is not None:
  162. context.prec = max_digits
  163. if decimal_places is not None:
  164. value = value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)
  165. else:
  166. context.traps[decimal.Rounded] = 1
  167. value = context.create_decimal(value)
  168. return "{:f}".format(value)
  169. if decimal_places is not None:
  170. return "%.*f" % (decimal_places, value)
  171. return "{:f}".format(value)