You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

995 lines
34 KiB

  1. import copy
  2. import datetime
  3. from django.conf import settings
  4. from django.core.exceptions import FieldError
  5. from django.db.backends import utils as backend_utils
  6. from django.db.models import fields
  7. from django.db.models.constants import LOOKUP_SEP
  8. from django.db.models.query_utils import Q, refs_aggregate
  9. from django.utils import six, timezone
  10. from django.utils.functional import cached_property
  11. class Combinable(object):
  12. """
  13. Provides the ability to combine one or two objects with
  14. some connector. For example F('foo') + F('bar').
  15. """
  16. # Arithmetic connectors
  17. ADD = '+'
  18. SUB = '-'
  19. MUL = '*'
  20. DIV = '/'
  21. POW = '^'
  22. # The following is a quoted % operator - it is quoted because it can be
  23. # used in strings that also have parameter substitution.
  24. MOD = '%%'
  25. # Bitwise operators - note that these are generated by .bitand()
  26. # and .bitor(), the '&' and '|' are reserved for boolean operator
  27. # usage.
  28. BITAND = '&'
  29. BITOR = '|'
  30. def _combine(self, other, connector, reversed, node=None):
  31. if not hasattr(other, 'resolve_expression'):
  32. # everything must be resolvable to an expression
  33. if isinstance(other, datetime.timedelta):
  34. other = DurationValue(other, output_field=fields.DurationField())
  35. else:
  36. other = Value(other)
  37. if reversed:
  38. return CombinedExpression(other, connector, self)
  39. return CombinedExpression(self, connector, other)
  40. #############
  41. # OPERATORS #
  42. #############
  43. def __add__(self, other):
  44. return self._combine(other, self.ADD, False)
  45. def __sub__(self, other):
  46. return self._combine(other, self.SUB, False)
  47. def __mul__(self, other):
  48. return self._combine(other, self.MUL, False)
  49. def __truediv__(self, other):
  50. return self._combine(other, self.DIV, False)
  51. def __div__(self, other): # Python 2 compatibility
  52. return type(self).__truediv__(self, other)
  53. def __mod__(self, other):
  54. return self._combine(other, self.MOD, False)
  55. def __pow__(self, other):
  56. return self._combine(other, self.POW, False)
  57. def __and__(self, other):
  58. raise NotImplementedError(
  59. "Use .bitand() and .bitor() for bitwise logical operations."
  60. )
  61. def bitand(self, other):
  62. return self._combine(other, self.BITAND, False)
  63. def __or__(self, other):
  64. raise NotImplementedError(
  65. "Use .bitand() and .bitor() for bitwise logical operations."
  66. )
  67. def bitor(self, other):
  68. return self._combine(other, self.BITOR, False)
  69. def __radd__(self, other):
  70. return self._combine(other, self.ADD, True)
  71. def __rsub__(self, other):
  72. return self._combine(other, self.SUB, True)
  73. def __rmul__(self, other):
  74. return self._combine(other, self.MUL, True)
  75. def __rtruediv__(self, other):
  76. return self._combine(other, self.DIV, True)
  77. def __rdiv__(self, other): # Python 2 compatibility
  78. return type(self).__rtruediv__(self, other)
  79. def __rmod__(self, other):
  80. return self._combine(other, self.MOD, True)
  81. def __rpow__(self, other):
  82. return self._combine(other, self.POW, True)
  83. def __rand__(self, other):
  84. raise NotImplementedError(
  85. "Use .bitand() and .bitor() for bitwise logical operations."
  86. )
  87. def __ror__(self, other):
  88. raise NotImplementedError(
  89. "Use .bitand() and .bitor() for bitwise logical operations."
  90. )
  91. class BaseExpression(object):
  92. """
  93. Base class for all query expressions.
  94. """
  95. # aggregate specific fields
  96. is_summary = False
  97. def __init__(self, output_field=None):
  98. self._output_field = output_field
  99. def get_db_converters(self, connection):
  100. return [self.convert_value] + self.output_field.get_db_converters(connection)
  101. def get_source_expressions(self):
  102. return []
  103. def set_source_expressions(self, exprs):
  104. assert len(exprs) == 0
  105. def _parse_expressions(self, *expressions):
  106. return [
  107. arg if hasattr(arg, 'resolve_expression') else (
  108. F(arg) if isinstance(arg, six.string_types) else Value(arg)
  109. ) for arg in expressions
  110. ]
  111. def as_sql(self, compiler, connection):
  112. """
  113. Responsible for returning a (sql, [params]) tuple to be included
  114. in the current query.
  115. Different backends can provide their own implementation, by
  116. providing an `as_{vendor}` method and patching the Expression:
  117. ```
  118. def override_as_sql(self, compiler, connection):
  119. # custom logic
  120. return super(Expression, self).as_sql(compiler, connection)
  121. setattr(Expression, 'as_' + connection.vendor, override_as_sql)
  122. ```
  123. Arguments:
  124. * compiler: the query compiler responsible for generating the query.
  125. Must have a compile method, returning a (sql, [params]) tuple.
  126. Calling compiler(value) will return a quoted `value`.
  127. * connection: the database connection used for the current query.
  128. Returns: (sql, params)
  129. Where `sql` is a string containing ordered sql parameters to be
  130. replaced with the elements of the list `params`.
  131. """
  132. raise NotImplementedError("Subclasses must implement as_sql()")
  133. @cached_property
  134. def contains_aggregate(self):
  135. for expr in self.get_source_expressions():
  136. if expr and expr.contains_aggregate:
  137. return True
  138. return False
  139. @cached_property
  140. def contains_column_references(self):
  141. for expr in self.get_source_expressions():
  142. if expr and expr.contains_column_references:
  143. return True
  144. return False
  145. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  146. """
  147. Provides the chance to do any preprocessing or validation before being
  148. added to the query.
  149. Arguments:
  150. * query: the backend query implementation
  151. * allow_joins: boolean allowing or denying use of joins
  152. in this query
  153. * reuse: a set of reusable joins for multijoins
  154. * summarize: a terminal aggregate clause
  155. * for_save: whether this expression about to be used in a save or update
  156. Returns: an Expression to be added to the query.
  157. """
  158. c = self.copy()
  159. c.is_summary = summarize
  160. c.set_source_expressions([
  161. expr.resolve_expression(query, allow_joins, reuse, summarize)
  162. for expr in c.get_source_expressions()
  163. ])
  164. return c
  165. def _prepare(self, field):
  166. """
  167. Hook used by Field.get_prep_lookup() to do custom preparation.
  168. """
  169. return self
  170. @property
  171. def field(self):
  172. return self.output_field
  173. @cached_property
  174. def output_field(self):
  175. """
  176. Returns the output type of this expressions.
  177. """
  178. if self._output_field_or_none is None:
  179. raise FieldError("Cannot resolve expression type, unknown output_field")
  180. return self._output_field_or_none
  181. @cached_property
  182. def _output_field_or_none(self):
  183. """
  184. Returns the output field of this expression, or None if no output type
  185. can be resolved. Note that the 'output_field' property will raise
  186. FieldError if no type can be resolved, but this attribute allows for
  187. None values.
  188. """
  189. if self._output_field is None:
  190. self._resolve_output_field()
  191. return self._output_field
  192. def _resolve_output_field(self):
  193. """
  194. Attempts to infer the output type of the expression. If the output
  195. fields of all source fields match then we can simply infer the same
  196. type here. This isn't always correct, but it makes sense most of the
  197. time.
  198. Consider the difference between `2 + 2` and `2 / 3`. Inferring
  199. the type here is a convenience for the common case. The user should
  200. supply their own output_field with more complex computations.
  201. If a source does not have an `_output_field` then we exclude it from
  202. this check. If all sources are `None`, then an error will be thrown
  203. higher up the stack in the `output_field` property.
  204. """
  205. if self._output_field is None:
  206. sources = self.get_source_fields()
  207. num_sources = len(sources)
  208. if num_sources == 0:
  209. self._output_field = None
  210. else:
  211. for source in sources:
  212. if self._output_field is None:
  213. self._output_field = source
  214. if source is not None and not isinstance(self._output_field, source.__class__):
  215. raise FieldError(
  216. "Expression contains mixed types. You must set output_field")
  217. def convert_value(self, value, expression, connection, context):
  218. """
  219. Expressions provide their own converters because users have the option
  220. of manually specifying the output_field which may be a different type
  221. from the one the database returns.
  222. """
  223. field = self.output_field
  224. internal_type = field.get_internal_type()
  225. if value is None:
  226. return value
  227. elif internal_type == 'FloatField':
  228. return float(value)
  229. elif internal_type.endswith('IntegerField'):
  230. return int(value)
  231. elif internal_type == 'DecimalField':
  232. return backend_utils.typecast_decimal(value)
  233. return value
  234. def get_lookup(self, lookup):
  235. return self.output_field.get_lookup(lookup)
  236. def get_transform(self, name):
  237. return self.output_field.get_transform(name)
  238. def relabeled_clone(self, change_map):
  239. clone = self.copy()
  240. clone.set_source_expressions(
  241. [e.relabeled_clone(change_map) for e in self.get_source_expressions()])
  242. return clone
  243. def copy(self):
  244. c = copy.copy(self)
  245. c.copied = True
  246. return c
  247. def refs_aggregate(self, existing_aggregates):
  248. """
  249. Does this expression contain a reference to some of the
  250. existing aggregates? If so, returns the aggregate and also
  251. the lookup parts that *weren't* found. So, if
  252. existing_aggregates = {'max_id': Max('id')}
  253. self.name = 'max_id'
  254. queryset.filter(max_id__range=[10,100])
  255. then this method will return Max('id') and those parts of the
  256. name that weren't found. In this case `max_id` is found and the range
  257. portion is returned as ('range',).
  258. """
  259. for node in self.get_source_expressions():
  260. agg, lookup = node.refs_aggregate(existing_aggregates)
  261. if agg:
  262. return agg, lookup
  263. return False, ()
  264. def get_group_by_cols(self):
  265. if not self.contains_aggregate:
  266. return [self]
  267. cols = []
  268. for source in self.get_source_expressions():
  269. cols.extend(source.get_group_by_cols())
  270. return cols
  271. def get_source_fields(self):
  272. """
  273. Returns the underlying field types used by this
  274. aggregate.
  275. """
  276. return [e._output_field_or_none for e in self.get_source_expressions()]
  277. def asc(self):
  278. return OrderBy(self)
  279. def desc(self):
  280. return OrderBy(self, descending=True)
  281. def reverse_ordering(self):
  282. return self
  283. def flatten(self):
  284. """
  285. Recursively yield this expression and all subexpressions, in
  286. depth-first order.
  287. """
  288. yield self
  289. for expr in self.get_source_expressions():
  290. if expr:
  291. for inner_expr in expr.flatten():
  292. yield inner_expr
  293. class Expression(BaseExpression, Combinable):
  294. """
  295. An expression that can be combined with other expressions.
  296. """
  297. pass
  298. class CombinedExpression(Expression):
  299. def __init__(self, lhs, connector, rhs, output_field=None):
  300. super(CombinedExpression, self).__init__(output_field=output_field)
  301. self.connector = connector
  302. self.lhs = lhs
  303. self.rhs = rhs
  304. def __repr__(self):
  305. return "<{}: {}>".format(self.__class__.__name__, self)
  306. def __str__(self):
  307. return "{} {} {}".format(self.lhs, self.connector, self.rhs)
  308. def get_source_expressions(self):
  309. return [self.lhs, self.rhs]
  310. def set_source_expressions(self, exprs):
  311. self.lhs, self.rhs = exprs
  312. def as_sql(self, compiler, connection):
  313. try:
  314. lhs_output = self.lhs.output_field
  315. except FieldError:
  316. lhs_output = None
  317. try:
  318. rhs_output = self.rhs.output_field
  319. except FieldError:
  320. rhs_output = None
  321. if (not connection.features.has_native_duration_field and
  322. ((lhs_output and lhs_output.get_internal_type() == 'DurationField')
  323. or (rhs_output and rhs_output.get_internal_type() == 'DurationField'))):
  324. return DurationExpression(self.lhs, self.connector, self.rhs).as_sql(compiler, connection)
  325. expressions = []
  326. expression_params = []
  327. sql, params = compiler.compile(self.lhs)
  328. expressions.append(sql)
  329. expression_params.extend(params)
  330. sql, params = compiler.compile(self.rhs)
  331. expressions.append(sql)
  332. expression_params.extend(params)
  333. # order of precedence
  334. expression_wrapper = '(%s)'
  335. sql = connection.ops.combine_expression(self.connector, expressions)
  336. return expression_wrapper % sql, expression_params
  337. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  338. c = self.copy()
  339. c.is_summary = summarize
  340. c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  341. c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  342. return c
  343. class DurationExpression(CombinedExpression):
  344. def compile(self, side, compiler, connection):
  345. if not isinstance(side, DurationValue):
  346. try:
  347. output = side.output_field
  348. except FieldError:
  349. pass
  350. else:
  351. if output.get_internal_type() == 'DurationField':
  352. sql, params = compiler.compile(side)
  353. return connection.ops.format_for_duration_arithmetic(sql), params
  354. return compiler.compile(side)
  355. def as_sql(self, compiler, connection):
  356. connection.ops.check_expression_support(self)
  357. expressions = []
  358. expression_params = []
  359. sql, params = self.compile(self.lhs, compiler, connection)
  360. expressions.append(sql)
  361. expression_params.extend(params)
  362. sql, params = self.compile(self.rhs, compiler, connection)
  363. expressions.append(sql)
  364. expression_params.extend(params)
  365. # order of precedence
  366. expression_wrapper = '(%s)'
  367. sql = connection.ops.combine_duration_expression(self.connector, expressions)
  368. return expression_wrapper % sql, expression_params
  369. class F(Combinable):
  370. """
  371. An object capable of resolving references to existing query objects.
  372. """
  373. def __init__(self, name):
  374. """
  375. Arguments:
  376. * name: the name of the field this expression references
  377. """
  378. self.name = name
  379. def __repr__(self):
  380. return "{}({})".format(self.__class__.__name__, self.name)
  381. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  382. return query.resolve_ref(self.name, allow_joins, reuse, summarize)
  383. def refs_aggregate(self, existing_aggregates):
  384. return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates)
  385. def asc(self):
  386. return OrderBy(self)
  387. def desc(self):
  388. return OrderBy(self, descending=True)
  389. class Func(Expression):
  390. """
  391. An SQL function call.
  392. """
  393. function = None
  394. template = '%(function)s(%(expressions)s)'
  395. arg_joiner = ', '
  396. def __init__(self, *expressions, **extra):
  397. output_field = extra.pop('output_field', None)
  398. super(Func, self).__init__(output_field=output_field)
  399. self.source_expressions = self._parse_expressions(*expressions)
  400. self.extra = extra
  401. def __repr__(self):
  402. args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  403. extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
  404. if extra:
  405. return "{}({}, {})".format(self.__class__.__name__, args, extra)
  406. return "{}({})".format(self.__class__.__name__, args)
  407. def get_source_expressions(self):
  408. return self.source_expressions
  409. def set_source_expressions(self, exprs):
  410. self.source_expressions = exprs
  411. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  412. c = self.copy()
  413. c.is_summary = summarize
  414. for pos, arg in enumerate(c.source_expressions):
  415. c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  416. return c
  417. def as_sql(self, compiler, connection, function=None, template=None):
  418. connection.ops.check_expression_support(self)
  419. sql_parts = []
  420. params = []
  421. for arg in self.source_expressions:
  422. arg_sql, arg_params = compiler.compile(arg)
  423. sql_parts.append(arg_sql)
  424. params.extend(arg_params)
  425. if function is None:
  426. self.extra['function'] = self.extra.get('function', self.function)
  427. else:
  428. self.extra['function'] = function
  429. self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
  430. template = template or self.extra.get('template', self.template)
  431. return template % self.extra, params
  432. def as_sqlite(self, *args, **kwargs):
  433. sql, params = self.as_sql(*args, **kwargs)
  434. try:
  435. if self.output_field.get_internal_type() == 'DecimalField':
  436. sql = 'CAST(%s AS NUMERIC)' % sql
  437. except FieldError:
  438. pass
  439. return sql, params
  440. def copy(self):
  441. copy = super(Func, self).copy()
  442. copy.source_expressions = self.source_expressions[:]
  443. copy.extra = self.extra.copy()
  444. return copy
  445. class Value(Expression):
  446. """
  447. Represents a wrapped value as a node within an expression
  448. """
  449. def __init__(self, value, output_field=None):
  450. """
  451. Arguments:
  452. * value: the value this expression represents. The value will be
  453. added into the sql parameter list and properly quoted.
  454. * output_field: an instance of the model field type that this
  455. expression will return, such as IntegerField() or CharField().
  456. """
  457. super(Value, self).__init__(output_field=output_field)
  458. self.value = value
  459. def __repr__(self):
  460. return "{}({})".format(self.__class__.__name__, self.value)
  461. def as_sql(self, compiler, connection):
  462. connection.ops.check_expression_support(self)
  463. val = self.value
  464. # check _output_field to avoid triggering an exception
  465. if self._output_field is not None:
  466. if self.for_save:
  467. val = self.output_field.get_db_prep_save(val, connection=connection)
  468. else:
  469. val = self.output_field.get_db_prep_value(val, connection=connection)
  470. if val is None:
  471. # cx_Oracle does not always convert None to the appropriate
  472. # NULL type (like in case expressions using numbers), so we
  473. # use a literal SQL NULL
  474. return 'NULL', []
  475. return '%s', [val]
  476. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  477. c = super(Value, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
  478. c.for_save = for_save
  479. return c
  480. def get_group_by_cols(self):
  481. return []
  482. class DurationValue(Value):
  483. def as_sql(self, compiler, connection):
  484. connection.ops.check_expression_support(self)
  485. if (connection.features.has_native_duration_field and
  486. connection.features.driver_supports_timedelta_args):
  487. return super(DurationValue, self).as_sql(compiler, connection)
  488. return connection.ops.date_interval_sql(self.value)
  489. class RawSQL(Expression):
  490. def __init__(self, sql, params, output_field=None):
  491. if output_field is None:
  492. output_field = fields.Field()
  493. self.sql, self.params = sql, params
  494. super(RawSQL, self).__init__(output_field=output_field)
  495. def __repr__(self):
  496. return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
  497. def as_sql(self, compiler, connection):
  498. return '(%s)' % self.sql, self.params
  499. def get_group_by_cols(self):
  500. return [self]
  501. class Star(Expression):
  502. def __repr__(self):
  503. return "'*'"
  504. def as_sql(self, compiler, connection):
  505. return '*', []
  506. class Random(Expression):
  507. def __init__(self):
  508. super(Random, self).__init__(output_field=fields.FloatField())
  509. def __repr__(self):
  510. return "Random()"
  511. def as_sql(self, compiler, connection):
  512. return connection.ops.random_function_sql(), []
  513. class Col(Expression):
  514. contains_column_references = True
  515. def __init__(self, alias, target, output_field=None):
  516. if output_field is None:
  517. output_field = target
  518. super(Col, self).__init__(output_field=output_field)
  519. self.alias, self.target = alias, target
  520. def __repr__(self):
  521. return "{}({}, {})".format(
  522. self.__class__.__name__, self.alias, self.target)
  523. def as_sql(self, compiler, connection):
  524. qn = compiler.quote_name_unless_alias
  525. return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
  526. def relabeled_clone(self, relabels):
  527. return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
  528. def get_group_by_cols(self):
  529. return [self]
  530. def get_db_converters(self, connection):
  531. if self.target == self.output_field:
  532. return self.output_field.get_db_converters(connection)
  533. return (self.output_field.get_db_converters(connection) +
  534. self.target.get_db_converters(connection))
  535. class Ref(Expression):
  536. """
  537. Reference to column alias of the query. For example, Ref('sum_cost') in
  538. qs.annotate(sum_cost=Sum('cost')) query.
  539. """
  540. def __init__(self, refs, source):
  541. super(Ref, self).__init__()
  542. self.refs, self.source = refs, source
  543. def __repr__(self):
  544. return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
  545. def get_source_expressions(self):
  546. return [self.source]
  547. def set_source_expressions(self, exprs):
  548. self.source, = exprs
  549. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  550. # The sub-expression `source` has already been resolved, as this is
  551. # just a reference to the name of `source`.
  552. return self
  553. def relabeled_clone(self, relabels):
  554. return self
  555. def as_sql(self, compiler, connection):
  556. return "%s" % connection.ops.quote_name(self.refs), []
  557. def get_group_by_cols(self):
  558. return [self]
  559. class ExpressionWrapper(Expression):
  560. """
  561. An expression that can wrap another expression so that it can provide
  562. extra context to the inner expression, such as the output_field.
  563. """
  564. def __init__(self, expression, output_field):
  565. super(ExpressionWrapper, self).__init__(output_field=output_field)
  566. self.expression = expression
  567. def set_source_expressions(self, exprs):
  568. self.expression = exprs[0]
  569. def get_source_expressions(self):
  570. return [self.expression]
  571. def as_sql(self, compiler, connection):
  572. return self.expression.as_sql(compiler, connection)
  573. def __repr__(self):
  574. return "{}({})".format(self.__class__.__name__, self.expression)
  575. class When(Expression):
  576. template = 'WHEN %(condition)s THEN %(result)s'
  577. def __init__(self, condition=None, then=None, **lookups):
  578. if lookups and condition is None:
  579. condition, lookups = Q(**lookups), None
  580. if condition is None or not isinstance(condition, Q) or lookups:
  581. raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
  582. super(When, self).__init__(output_field=None)
  583. self.condition = condition
  584. self.result = self._parse_expressions(then)[0]
  585. def __str__(self):
  586. return "WHEN %r THEN %r" % (self.condition, self.result)
  587. def __repr__(self):
  588. return "<%s: %s>" % (self.__class__.__name__, self)
  589. def get_source_expressions(self):
  590. return [self.condition, self.result]
  591. def set_source_expressions(self, exprs):
  592. self.condition, self.result = exprs
  593. def get_source_fields(self):
  594. # We're only interested in the fields of the result expressions.
  595. return [self.result._output_field_or_none]
  596. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  597. c = self.copy()
  598. c.is_summary = summarize
  599. if hasattr(c.condition, 'resolve_expression'):
  600. c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False)
  601. c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  602. return c
  603. def as_sql(self, compiler, connection, template=None):
  604. connection.ops.check_expression_support(self)
  605. template_params = {}
  606. sql_params = []
  607. condition_sql, condition_params = compiler.compile(self.condition)
  608. template_params['condition'] = condition_sql
  609. sql_params.extend(condition_params)
  610. result_sql, result_params = compiler.compile(self.result)
  611. template_params['result'] = result_sql
  612. sql_params.extend(result_params)
  613. template = template or self.template
  614. return template % template_params, sql_params
  615. def get_group_by_cols(self):
  616. # This is not a complete expression and cannot be used in GROUP BY.
  617. cols = []
  618. for source in self.get_source_expressions():
  619. cols.extend(source.get_group_by_cols())
  620. return cols
  621. class Case(Expression):
  622. """
  623. An SQL searched CASE expression:
  624. CASE
  625. WHEN n > 0
  626. THEN 'positive'
  627. WHEN n < 0
  628. THEN 'negative'
  629. ELSE 'zero'
  630. END
  631. """
  632. template = 'CASE %(cases)s ELSE %(default)s END'
  633. case_joiner = ' '
  634. def __init__(self, *cases, **extra):
  635. if not all(isinstance(case, When) for case in cases):
  636. raise TypeError("Positional arguments must all be When objects.")
  637. default = extra.pop('default', None)
  638. output_field = extra.pop('output_field', None)
  639. super(Case, self).__init__(output_field)
  640. self.cases = list(cases)
  641. self.default = self._parse_expressions(default)[0]
  642. def __str__(self):
  643. return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
  644. def __repr__(self):
  645. return "<%s: %s>" % (self.__class__.__name__, self)
  646. def get_source_expressions(self):
  647. return self.cases + [self.default]
  648. def set_source_expressions(self, exprs):
  649. self.cases = exprs[:-1]
  650. self.default = exprs[-1]
  651. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  652. c = self.copy()
  653. c.is_summary = summarize
  654. for pos, case in enumerate(c.cases):
  655. c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  656. c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
  657. return c
  658. def copy(self):
  659. c = super(Case, self).copy()
  660. c.cases = c.cases[:]
  661. return c
  662. def as_sql(self, compiler, connection, template=None, extra=None):
  663. connection.ops.check_expression_support(self)
  664. if not self.cases:
  665. return compiler.compile(self.default)
  666. template_params = dict(extra) if extra else {}
  667. case_parts = []
  668. sql_params = []
  669. for case in self.cases:
  670. case_sql, case_params = compiler.compile(case)
  671. case_parts.append(case_sql)
  672. sql_params.extend(case_params)
  673. template_params['cases'] = self.case_joiner.join(case_parts)
  674. default_sql, default_params = compiler.compile(self.default)
  675. template_params['default'] = default_sql
  676. sql_params.extend(default_params)
  677. template = template or self.template
  678. sql = template % template_params
  679. if self._output_field_or_none is not None:
  680. sql = connection.ops.unification_cast_sql(self.output_field) % sql
  681. return sql, sql_params
  682. class Date(Expression):
  683. """
  684. Add a date selection column.
  685. """
  686. def __init__(self, lookup, lookup_type):
  687. super(Date, self).__init__(output_field=fields.DateField())
  688. self.lookup = lookup
  689. self.col = None
  690. self.lookup_type = lookup_type
  691. def __repr__(self):
  692. return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
  693. def get_source_expressions(self):
  694. return [self.col]
  695. def set_source_expressions(self, exprs):
  696. self.col, = exprs
  697. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  698. copy = self.copy()
  699. copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
  700. field = copy.col.output_field
  701. assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
  702. if settings.USE_TZ:
  703. assert not isinstance(field, fields.DateTimeField), (
  704. "%r is a DateTimeField, not a DateField." % field.name
  705. )
  706. return copy
  707. def as_sql(self, compiler, connection):
  708. sql, params = self.col.as_sql(compiler, connection)
  709. assert not(params)
  710. return connection.ops.date_trunc_sql(self.lookup_type, sql), []
  711. def copy(self):
  712. copy = super(Date, self).copy()
  713. copy.lookup = self.lookup
  714. copy.lookup_type = self.lookup_type
  715. return copy
  716. def convert_value(self, value, expression, connection, context):
  717. if isinstance(value, datetime.datetime):
  718. value = value.date()
  719. return value
  720. class DateTime(Expression):
  721. """
  722. Add a datetime selection column.
  723. """
  724. def __init__(self, lookup, lookup_type, tzinfo):
  725. super(DateTime, self).__init__(output_field=fields.DateTimeField())
  726. self.lookup = lookup
  727. self.col = None
  728. self.lookup_type = lookup_type
  729. if tzinfo is None:
  730. self.tzname = None
  731. else:
  732. self.tzname = timezone._get_timezone_name(tzinfo)
  733. self.tzinfo = tzinfo
  734. def __repr__(self):
  735. return "{}({}, {}, {})".format(
  736. self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
  737. def get_source_expressions(self):
  738. return [self.col]
  739. def set_source_expressions(self, exprs):
  740. self.col, = exprs
  741. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  742. copy = self.copy()
  743. copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
  744. field = copy.col.output_field
  745. assert isinstance(field, fields.DateTimeField), (
  746. "%r isn't a DateTimeField." % field.name
  747. )
  748. return copy
  749. def as_sql(self, compiler, connection):
  750. sql, params = self.col.as_sql(compiler, connection)
  751. assert not(params)
  752. return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
  753. def copy(self):
  754. copy = super(DateTime, self).copy()
  755. copy.lookup = self.lookup
  756. copy.lookup_type = self.lookup_type
  757. copy.tzname = self.tzname
  758. return copy
  759. def convert_value(self, value, expression, connection, context):
  760. if settings.USE_TZ:
  761. if value is None:
  762. raise ValueError(
  763. "Database returned an invalid value in QuerySet.datetimes(). "
  764. "Are time zone definitions for your database and pytz installed?"
  765. )
  766. value = value.replace(tzinfo=None)
  767. value = timezone.make_aware(value, self.tzinfo)
  768. return value
  769. class OrderBy(BaseExpression):
  770. template = '%(expression)s %(ordering)s'
  771. def __init__(self, expression, descending=False):
  772. self.descending = descending
  773. if not hasattr(expression, 'resolve_expression'):
  774. raise ValueError('expression must be an expression type')
  775. self.expression = expression
  776. def __repr__(self):
  777. return "{}({}, descending={})".format(
  778. self.__class__.__name__, self.expression, self.descending)
  779. def set_source_expressions(self, exprs):
  780. self.expression = exprs[0]
  781. def get_source_expressions(self):
  782. return [self.expression]
  783. def as_sql(self, compiler, connection):
  784. connection.ops.check_expression_support(self)
  785. expression_sql, params = compiler.compile(self.expression)
  786. placeholders = {'expression': expression_sql}
  787. placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
  788. return (self.template % placeholders).rstrip(), params
  789. def get_group_by_cols(self):
  790. cols = []
  791. for source in self.get_source_expressions():
  792. cols.extend(source.get_group_by_cols())
  793. return cols
  794. def reverse_ordering(self):
  795. self.descending = not self.descending
  796. return self
  797. def asc(self):
  798. self.descending = False
  799. def desc(self):
  800. self.descending = True