You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

176 lines
6.4 KiB

  1. """
  2. Classes to represent the definitions of aggregate functions.
  3. """
  4. from django.core.exceptions import FieldError
  5. from django.db.models.expressions import Func, Star
  6. from django.db.models.fields import FloatField, IntegerField
  7. __all__ = [
  8. 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
  9. ]
  10. class Aggregate(Func):
  11. contains_aggregate = True
  12. name = None
  13. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  14. # Aggregates are not allowed in UPDATE queries, so ignore for_save
  15. c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
  16. if not summarize:
  17. expressions = c.get_source_expressions()
  18. for index, expr in enumerate(expressions):
  19. if expr.contains_aggregate:
  20. before_resolved = self.get_source_expressions()[index]
  21. name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
  22. raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
  23. c._patch_aggregate(query) # backward-compatibility support
  24. return c
  25. @property
  26. def default_alias(self):
  27. expressions = self.get_source_expressions()
  28. if len(expressions) == 1 and hasattr(expressions[0], 'name'):
  29. return '%s__%s' % (expressions[0].name, self.name.lower())
  30. raise TypeError("Complex expressions require an alias")
  31. def get_group_by_cols(self):
  32. return []
  33. def _patch_aggregate(self, query):
  34. """
  35. Helper method for patching 3rd party aggregates that do not yet support
  36. the new way of subclassing. This method will be removed in Django 1.10.
  37. add_to_query(query, alias, col, source, is_summary) will be defined on
  38. legacy aggregates which, in turn, instantiates the SQL implementation of
  39. the aggregate. In all the cases found, the general implementation of
  40. add_to_query looks like:
  41. def add_to_query(self, query, alias, col, source, is_summary):
  42. klass = SQLImplementationAggregate
  43. aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
  44. query.aggregates[alias] = aggregate
  45. By supplying a known alias, we can get the SQLAggregate out of the
  46. aggregates dict, and use the sql_function and sql_template attributes
  47. to patch *this* aggregate.
  48. """
  49. if not hasattr(self, 'add_to_query') or self.function is not None:
  50. return
  51. placeholder_alias = "_XXXXXXXX_"
  52. self.add_to_query(query, placeholder_alias, None, None, None)
  53. sql_aggregate = query.aggregates.pop(placeholder_alias)
  54. if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'):
  55. self.extra['function'] = sql_aggregate.sql_function
  56. if hasattr(sql_aggregate, 'sql_template'):
  57. self.extra['template'] = sql_aggregate.sql_template
  58. class Avg(Aggregate):
  59. function = 'AVG'
  60. name = 'Avg'
  61. def __init__(self, expression, **extra):
  62. output_field = extra.pop('output_field', FloatField())
  63. super(Avg, self).__init__(expression, output_field=output_field, **extra)
  64. def as_oracle(self, compiler, connection):
  65. if self.output_field.get_internal_type() == 'DurationField':
  66. expression = self.get_source_expressions()[0]
  67. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  68. return compiler.compile(
  69. SecondsToInterval(Avg(IntervalToSeconds(expression)))
  70. )
  71. return super(Avg, self).as_sql(compiler, connection)
  72. class Count(Aggregate):
  73. function = 'COUNT'
  74. name = 'Count'
  75. template = '%(function)s(%(distinct)s%(expressions)s)'
  76. def __init__(self, expression, distinct=False, **extra):
  77. if expression == '*':
  78. expression = Star()
  79. super(Count, self).__init__(
  80. expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
  81. def __repr__(self):
  82. return "{}({}, distinct={})".format(
  83. self.__class__.__name__,
  84. self.arg_joiner.join(str(arg) for arg in self.source_expressions),
  85. 'False' if self.extra['distinct'] == '' else 'True',
  86. )
  87. def convert_value(self, value, expression, connection, context):
  88. if value is None:
  89. return 0
  90. return int(value)
  91. class Max(Aggregate):
  92. function = 'MAX'
  93. name = 'Max'
  94. class Min(Aggregate):
  95. function = 'MIN'
  96. name = 'Min'
  97. class StdDev(Aggregate):
  98. name = 'StdDev'
  99. def __init__(self, expression, sample=False, **extra):
  100. self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  101. super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
  102. def __repr__(self):
  103. return "{}({}, sample={})".format(
  104. self.__class__.__name__,
  105. self.arg_joiner.join(str(arg) for arg in self.source_expressions),
  106. 'False' if self.function == 'STDDEV_POP' else 'True',
  107. )
  108. def convert_value(self, value, expression, connection, context):
  109. if value is None:
  110. return value
  111. return float(value)
  112. class Sum(Aggregate):
  113. function = 'SUM'
  114. name = 'Sum'
  115. def as_oracle(self, compiler, connection):
  116. if self.output_field.get_internal_type() == 'DurationField':
  117. expression = self.get_source_expressions()[0]
  118. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  119. return compiler.compile(
  120. SecondsToInterval(Sum(IntervalToSeconds(expression)))
  121. )
  122. return super(Sum, self).as_sql(compiler, connection)
  123. class Variance(Aggregate):
  124. name = 'Variance'
  125. def __init__(self, expression, sample=False, **extra):
  126. self.function = 'VAR_SAMP' if sample else 'VAR_POP'
  127. super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
  128. def __repr__(self):
  129. return "{}({}, sample={})".format(
  130. self.__class__.__name__,
  131. self.arg_joiner.join(str(arg) for arg in self.source_expressions),
  132. 'False' if self.function == 'VAR_POP' else 'True',
  133. )
  134. def convert_value(self, value, expression, connection, context):
  135. if value is None:
  136. return value
  137. return float(value)