You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

aggregates.py 4.7 KiB

(line-number gutter: lines 1–155)
  1. """
  2. Classes to represent the default SQL aggregate functions
  3. """
  4. import copy
  5. import warnings
  6. from django.db.models.fields import FloatField, IntegerField
  7. from django.db.models.query_utils import RegisterLookupMixin
  8. from django.utils.deprecation import RemovedInDjango110Warning
  9. from django.utils.functional import cached_property
  10. __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
  11. warnings.warn(
  12. "django.db.models.sql.aggregates is deprecated. Use "
  13. "django.db.models.aggregates instead.",
  14. RemovedInDjango110Warning, stacklevel=2)
  15. class Aggregate(RegisterLookupMixin):
  16. """
  17. Default SQL Aggregate.
  18. """
  19. is_ordinal = False
  20. is_computed = False
  21. sql_template = '%(function)s(%(field)s)'
  22. def __init__(self, col, source=None, is_summary=False, **extra):
  23. """Instantiate an SQL aggregate
  24. * col is a column reference describing the subject field
  25. of the aggregate. It can be an alias, or a tuple describing
  26. a table and column name.
  27. * source is the underlying field or aggregate definition for
  28. the column reference. If the aggregate is not an ordinal or
  29. computed type, this reference is used to determine the coerced
  30. output type of the aggregate.
  31. * extra is a dictionary of additional data to provide for the
  32. aggregate definition
  33. Also utilizes the class variables:
  34. * sql_function, the name of the SQL function that implements the
  35. aggregate.
  36. * sql_template, a template string that is used to render the
  37. aggregate into SQL.
  38. * is_ordinal, a boolean indicating if the output of this aggregate
  39. is an integer (e.g., a count)
  40. * is_computed, a boolean indicating if this output of this aggregate
  41. is a computed float (e.g., an average), regardless of the input
  42. type.
  43. """
  44. self.col = col
  45. self.source = source
  46. self.is_summary = is_summary
  47. self.extra = extra
  48. # Follow the chain of aggregate sources back until you find an
  49. # actual field, or an aggregate that forces a particular output
  50. # type. This type of this field will be used to coerce values
  51. # retrieved from the database.
  52. tmp = self
  53. while tmp and isinstance(tmp, Aggregate):
  54. if getattr(tmp, 'is_ordinal', False):
  55. tmp = self._ordinal_aggregate_field
  56. elif getattr(tmp, 'is_computed', False):
  57. tmp = self._computed_aggregate_field
  58. else:
  59. tmp = tmp.source
  60. self.field = tmp
  61. # Two fake fields used to identify aggregate types in data-conversion operations.
  62. @cached_property
  63. def _ordinal_aggregate_field(self):
  64. return IntegerField()
  65. @cached_property
  66. def _computed_aggregate_field(self):
  67. return FloatField()
  68. def relabeled_clone(self, change_map):
  69. clone = copy.copy(self)
  70. if isinstance(self.col, (list, tuple)):
  71. clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
  72. return clone
  73. def as_sql(self, compiler, connection):
  74. "Return the aggregate, rendered as SQL with parameters."
  75. params = []
  76. if hasattr(self.col, 'as_sql'):
  77. field_name, params = self.col.as_sql(compiler, connection)
  78. elif isinstance(self.col, (list, tuple)):
  79. field_name = '.'.join(compiler(c) for c in self.col)
  80. else:
  81. field_name = compiler(self.col)
  82. substitutions = {
  83. 'function': self.sql_function,
  84. 'field': field_name
  85. }
  86. substitutions.update(self.extra)
  87. return self.sql_template % substitutions, params
  88. def get_group_by_cols(self):
  89. return []
  90. @property
  91. def output_field(self):
  92. return self.field
  93. class Avg(Aggregate):
  94. is_computed = True
  95. sql_function = 'AVG'
  96. class Count(Aggregate):
  97. is_ordinal = True
  98. sql_function = 'COUNT'
  99. sql_template = '%(function)s(%(distinct)s%(field)s)'
  100. def __init__(self, col, distinct=False, **extra):
  101. super(Count, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra)
  102. class Max(Aggregate):
  103. sql_function = 'MAX'
  104. class Min(Aggregate):
  105. sql_function = 'MIN'
  106. class StdDev(Aggregate):
  107. is_computed = True
  108. def __init__(self, col, sample=False, **extra):
  109. super(StdDev, self).__init__(col, **extra)
  110. self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  111. class Sum(Aggregate):
  112. sql_function = 'SUM'
  113. class Variance(Aggregate):
  114. is_computed = True
  115. def __init__(self, col, sample=False, **extra):
  116. super(Variance, self).__init__(col, **extra)
  117. self.sql_function = 'VAR_SAMP' if sample else 'VAR_POP'