You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lookups.py 19 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. from copy import copy
  2. from django.conf import settings
  3. from django.db.models.expressions import Func, Value
  4. from django.db.models.fields import (
  5. DateField, DateTimeField, Field, IntegerField, TimeField,
  6. )
  7. from django.db.models.query_utils import RegisterLookupMixin
  8. from django.utils import timezone
  9. from django.utils.functional import cached_property
  10. from django.utils.six.moves import range
class Lookup(object):
    """
    Base class for a lookup: a comparison between ``lhs`` and ``rhs``.

    ``lhs`` is an expression that exposes ``output_field`` and is compiled
    with ``compiler.compile()``. ``rhs`` is either a plain Python value or an
    object able to produce its own SQL (see ``rhs_is_direct_value()``).
    Subclasses set ``lookup_name`` and implement ``as_sql()``.
    """
    lookup_name = None

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # Let the lhs field prepare the raw rhs value for this lookup type.
        self.rhs = self.get_prep_lookup()
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            # QuerySet is imported here (not at module level) to avoid a
            # circular import.
            from django.db.models.query import QuerySet
            if isinstance(rhs, QuerySet):
                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
        # Transform classes that must be applied to the rhs as well as the lhs.
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in each bilateral transform class, in list order."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile every element of an iterable rhs (``self.rhs`` by default).

        Return a ``(sqls, params)`` pair: a list of SQL fragments and a flat
        list of the parameters they use.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                # Each value goes through the same transforms as the lhs.
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            params = self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, rhs, connection, prepared=True)
            # One plain placeholder per prepared parameter.
            sqls, sqls_params = ['%s'] * len(params), params
        return sqls, sqls_params

    def get_prep_lookup(self):
        """Delegate rhs preparation to the lhs output field."""
        return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)

    def get_db_prep_lookup(self, value, connection):
        """Return ``('%s', params)`` for a plain, already-prepared value."""
        return (
            '%s', self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, value, connection, prepared=True))

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile ``lhs`` (or ``self.lhs`` when not given) to ``(sql, params)``."""
        # NOTE: any falsy ``lhs`` argument also falls back to ``self.lhs``.
        lhs = lhs or self.lhs
        return compiler.compile(lhs)

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs to ``(sql, params)``.

        SQL-producing values (queries, expressions) are wrapped in
        parentheses; plain values go through ``get_db_prep_lookup()``.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        # Due to historical reasons there are a couple of different
        # ways to produce sql here. get_compiler is likely a Query
        # instance, _as_sql QuerySet and as_sql just something with
        # as_sql. Finally the value can of course be just plain
        # Python value.
        if hasattr(value, 'get_compiler'):
            value = value.get_compiler(connection=connection)
        if hasattr(value, 'as_sql'):
            sql, params = compiler.compile(value)
            return '(' + sql + ')', params
        if hasattr(value, '_as_sql'):
            sql, params = value._as_sql(connection=connection)
            return '(' + sql + ')', params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """True when the rhs cannot produce SQL itself (a plain value)."""
        return not(
            hasattr(self.rhs, 'as_sql') or
            hasattr(self.rhs, '_as_sql') or
            hasattr(self.rhs, 'get_compiler'))

    def relabeled_clone(self, relabels):
        """Return a shallow copy with lhs (and rhs, if supported) relabeled."""
        new = copy(self)
        new.lhs = new.lhs.relabeled_clone(relabels)
        if hasattr(new.rhs, 'relabeled_clone'):
            new.rhs = new.rhs.relabeled_clone(relabels)
        return new

    def get_group_by_cols(self):
        """Return lhs group-by columns, extended in place with the rhs's."""
        cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            cols.extend(self.rhs.get_group_by_cols())
        return cols

    def as_sql(self, compiler, connection):
        # Must be implemented by subclasses.
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        """True if either side contains an aggregate."""
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
  103. class Transform(RegisterLookupMixin, Func):
  104. """
  105. RegisterLookupMixin() is first so that get_lookup() and get_transform()
  106. first examine self and then check output_field.
  107. """
  108. bilateral = False
  109. def __init__(self, expression, **extra):
  110. # Restrict Transform to allow only a single expression.
  111. super(Transform, self).__init__(expression, **extra)
  112. @property
  113. def lhs(self):
  114. return self.get_source_expressions()[0]
  115. def get_bilateral_transforms(self):
  116. if hasattr(self.lhs, 'get_bilateral_transforms'):
  117. bilateral_transforms = self.lhs.get_bilateral_transforms()
  118. else:
  119. bilateral_transforms = []
  120. if self.bilateral:
  121. bilateral_transforms.append(self.__class__)
  122. return bilateral_transforms
  123. class BuiltinLookup(Lookup):
  124. def process_lhs(self, compiler, connection, lhs=None):
  125. lhs_sql, params = super(BuiltinLookup, self).process_lhs(
  126. compiler, connection, lhs)
  127. field_internal_type = self.lhs.output_field.get_internal_type()
  128. db_type = self.lhs.output_field.db_type(connection=connection)
  129. lhs_sql = connection.ops.field_cast_sql(
  130. db_type, field_internal_type) % lhs_sql
  131. lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
  132. return lhs_sql, list(params)
  133. def as_sql(self, compiler, connection):
  134. lhs_sql, params = self.process_lhs(compiler, connection)
  135. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  136. params.extend(rhs_params)
  137. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  138. return '%s %s' % (lhs_sql, rhs_sql), params
  139. def get_rhs_op(self, connection, rhs):
  140. return connection.operators[self.lookup_name] % rhs
  141. class Exact(BuiltinLookup):
  142. lookup_name = 'exact'
  143. Field.register_lookup(Exact)
  144. class IExact(BuiltinLookup):
  145. lookup_name = 'iexact'
  146. def process_rhs(self, qn, connection):
  147. rhs, params = super(IExact, self).process_rhs(qn, connection)
  148. if params:
  149. params[0] = connection.ops.prep_for_iexact_query(params[0])
  150. return rhs, params
  151. Field.register_lookup(IExact)
  152. class GreaterThan(BuiltinLookup):
  153. lookup_name = 'gt'
  154. Field.register_lookup(GreaterThan)
  155. class GreaterThanOrEqual(BuiltinLookup):
  156. lookup_name = 'gte'
  157. Field.register_lookup(GreaterThanOrEqual)
  158. class LessThan(BuiltinLookup):
  159. lookup_name = 'lt'
  160. Field.register_lookup(LessThan)
  161. class LessThanOrEqual(BuiltinLookup):
  162. lookup_name = 'lte'
  163. Field.register_lookup(LessThanOrEqual)
  164. class In(BuiltinLookup):
  165. lookup_name = 'in'
  166. def process_rhs(self, compiler, connection):
  167. if self.rhs_is_direct_value():
  168. # rhs should be an iterable, we use batch_process_rhs
  169. # to prepare/transform those values
  170. rhs = list(self.rhs)
  171. if not rhs:
  172. from django.db.models.sql.datastructures import EmptyResultSet
  173. raise EmptyResultSet
  174. sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
  175. placeholder = '(' + ', '.join(sqls) + ')'
  176. return (placeholder, sqls_params)
  177. else:
  178. return super(In, self).process_rhs(compiler, connection)
  179. def get_rhs_op(self, connection, rhs):
  180. return 'IN %s' % rhs
  181. def as_sql(self, compiler, connection):
  182. max_in_list_size = connection.ops.max_in_list_size()
  183. if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
  184. return self.split_parameter_list_as_sql(compiler, connection)
  185. return super(In, self).as_sql(compiler, connection)
  186. def split_parameter_list_as_sql(self, compiler, connection):
  187. # This is a special case for databases which limit the number of
  188. # elements which can appear in an 'IN' clause.
  189. max_in_list_size = connection.ops.max_in_list_size()
  190. lhs, lhs_params = self.process_lhs(compiler, connection)
  191. rhs, rhs_params = self.batch_process_rhs(compiler, connection)
  192. in_clause_elements = ['(']
  193. params = []
  194. for offset in range(0, len(rhs_params), max_in_list_size):
  195. if offset > 0:
  196. in_clause_elements.append(' OR ')
  197. in_clause_elements.append('%s IN (' % lhs)
  198. params.extend(lhs_params)
  199. sqls = rhs[offset: offset + max_in_list_size]
  200. sqls_params = rhs_params[offset: offset + max_in_list_size]
  201. param_group = ', '.join(sqls)
  202. in_clause_elements.append(param_group)
  203. in_clause_elements.append(')')
  204. params.extend(sqls_params)
  205. in_clause_elements.append(')')
  206. return ''.join(in_clause_elements), params
  207. Field.register_lookup(In)
  208. class PatternLookup(BuiltinLookup):
  209. def get_rhs_op(self, connection, rhs):
  210. # Assume we are in startswith. We need to produce SQL like:
  211. # col LIKE %s, ['thevalue%']
  212. # For python values we can (and should) do that directly in Python,
  213. # but if the value is for example reference to other column, then
  214. # we need to add the % pattern match to the lookup by something like
  215. # col LIKE othercol || '%%'
  216. # So, for Python values we don't need any special pattern, but for
  217. # SQL reference values or SQL transformations we need the correct
  218. # pattern added.
  219. if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
  220. or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
  221. pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
  222. return pattern.format(rhs)
  223. else:
  224. return super(PatternLookup, self).get_rhs_op(connection, rhs)
  225. class Contains(PatternLookup):
  226. lookup_name = 'contains'
  227. def process_rhs(self, qn, connection):
  228. rhs, params = super(Contains, self).process_rhs(qn, connection)
  229. if params and not self.bilateral_transforms:
  230. params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
  231. return rhs, params
  232. Field.register_lookup(Contains)
  233. class IContains(Contains):
  234. lookup_name = 'icontains'
  235. Field.register_lookup(IContains)
  236. class StartsWith(PatternLookup):
  237. lookup_name = 'startswith'
  238. def process_rhs(self, qn, connection):
  239. rhs, params = super(StartsWith, self).process_rhs(qn, connection)
  240. if params and not self.bilateral_transforms:
  241. params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
  242. return rhs, params
  243. Field.register_lookup(StartsWith)
  244. class IStartsWith(PatternLookup):
  245. lookup_name = 'istartswith'
  246. def process_rhs(self, qn, connection):
  247. rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
  248. if params and not self.bilateral_transforms:
  249. params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
  250. return rhs, params
  251. Field.register_lookup(IStartsWith)
  252. class EndsWith(PatternLookup):
  253. lookup_name = 'endswith'
  254. def process_rhs(self, qn, connection):
  255. rhs, params = super(EndsWith, self).process_rhs(qn, connection)
  256. if params and not self.bilateral_transforms:
  257. params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
  258. return rhs, params
  259. Field.register_lookup(EndsWith)
  260. class IEndsWith(PatternLookup):
  261. lookup_name = 'iendswith'
  262. def process_rhs(self, qn, connection):
  263. rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
  264. if params and not self.bilateral_transforms:
  265. params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
  266. return rhs, params
  267. Field.register_lookup(IEndsWith)
  268. class Between(BuiltinLookup):
  269. def get_rhs_op(self, connection, rhs):
  270. return "BETWEEN %s AND %s" % (rhs, rhs)
  271. class Range(BuiltinLookup):
  272. lookup_name = 'range'
  273. def get_rhs_op(self, connection, rhs):
  274. return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
  275. def process_rhs(self, compiler, connection):
  276. if self.rhs_is_direct_value():
  277. # rhs should be an iterable of 2 values, we use batch_process_rhs
  278. # to prepare/transform those values
  279. return self.batch_process_rhs(compiler, connection)
  280. else:
  281. return super(Range, self).process_rhs(compiler, connection)
  282. Field.register_lookup(Range)
  283. class IsNull(BuiltinLookup):
  284. lookup_name = 'isnull'
  285. def as_sql(self, compiler, connection):
  286. sql, params = compiler.compile(self.lhs)
  287. if self.rhs:
  288. return "%s IS NULL" % sql, params
  289. else:
  290. return "%s IS NOT NULL" % sql, params
  291. Field.register_lookup(IsNull)
  292. class Search(BuiltinLookup):
  293. lookup_name = 'search'
  294. def as_sql(self, compiler, connection):
  295. lhs, lhs_params = self.process_lhs(compiler, connection)
  296. rhs, rhs_params = self.process_rhs(compiler, connection)
  297. sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
  298. return sql_template, lhs_params + rhs_params
  299. Field.register_lookup(Search)
  300. class Regex(BuiltinLookup):
  301. lookup_name = 'regex'
  302. def as_sql(self, compiler, connection):
  303. if self.lookup_name in connection.operators:
  304. return super(Regex, self).as_sql(compiler, connection)
  305. else:
  306. lhs, lhs_params = self.process_lhs(compiler, connection)
  307. rhs, rhs_params = self.process_rhs(compiler, connection)
  308. sql_template = connection.ops.regex_lookup(self.lookup_name)
  309. return sql_template % (lhs, rhs), lhs_params + rhs_params
  310. Field.register_lookup(Regex)
  311. class IRegex(Regex):
  312. lookup_name = 'iregex'
  313. Field.register_lookup(IRegex)
  314. class DateTimeDateTransform(Transform):
  315. lookup_name = 'date'
  316. @cached_property
  317. def output_field(self):
  318. return DateField()
  319. def as_sql(self, compiler, connection):
  320. lhs, lhs_params = compiler.compile(self.lhs)
  321. tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
  322. sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
  323. lhs_params.extend(tz_params)
  324. return sql, lhs_params
  325. class DateTransform(Transform):
  326. def as_sql(self, compiler, connection):
  327. sql, params = compiler.compile(self.lhs)
  328. lhs_output_field = self.lhs.output_field
  329. if isinstance(lhs_output_field, DateTimeField):
  330. tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
  331. sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
  332. params.extend(tz_params)
  333. elif isinstance(lhs_output_field, DateField):
  334. sql = connection.ops.date_extract_sql(self.lookup_name, sql)
  335. elif isinstance(lhs_output_field, TimeField):
  336. sql = connection.ops.time_extract_sql(self.lookup_name, sql)
  337. else:
  338. raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
  339. return sql, params
  340. @cached_property
  341. def output_field(self):
  342. return IntegerField()
  343. class YearTransform(DateTransform):
  344. lookup_name = 'year'
  345. class YearLookup(Lookup):
  346. def year_lookup_bounds(self, connection, year):
  347. output_field = self.lhs.lhs.output_field
  348. if isinstance(output_field, DateTimeField):
  349. bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
  350. else:
  351. bounds = connection.ops.year_lookup_bounds_for_date_field(year)
  352. return bounds
  353. @YearTransform.register_lookup
  354. class YearExact(YearLookup):
  355. lookup_name = 'exact'
  356. def as_sql(self, compiler, connection):
  357. # We will need to skip the extract part and instead go
  358. # directly with the originating field, that is self.lhs.lhs.
  359. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  360. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  361. bounds = self.year_lookup_bounds(connection, rhs_params[0])
  362. params.extend(bounds)
  363. return '%s BETWEEN %%s AND %%s' % lhs_sql, params
  364. class YearComparisonLookup(YearLookup):
  365. def as_sql(self, compiler, connection):
  366. # We will need to skip the extract part and instead go
  367. # directly with the originating field, that is self.lhs.lhs.
  368. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  369. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  370. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  371. start, finish = self.year_lookup_bounds(connection, rhs_params[0])
  372. params.append(self.get_bound(start, finish))
  373. return '%s %s' % (lhs_sql, rhs_sql), params
  374. def get_rhs_op(self, connection, rhs):
  375. return connection.operators[self.lookup_name] % rhs
  376. def get_bound(self):
  377. raise NotImplementedError(
  378. 'subclasses of YearComparisonLookup must provide a get_bound() method'
  379. )
  380. @YearTransform.register_lookup
  381. class YearGt(YearComparisonLookup):
  382. lookup_name = 'gt'
  383. def get_bound(self, start, finish):
  384. return finish
  385. @YearTransform.register_lookup
  386. class YearGte(YearComparisonLookup):
  387. lookup_name = 'gte'
  388. def get_bound(self, start, finish):
  389. return start
  390. @YearTransform.register_lookup
  391. class YearLt(YearComparisonLookup):
  392. lookup_name = 'lt'
  393. def get_bound(self, start, finish):
  394. return start
  395. @YearTransform.register_lookup
  396. class YearLte(YearComparisonLookup):
  397. lookup_name = 'lte'
  398. def get_bound(self, start, finish):
  399. return finish
  400. class MonthTransform(DateTransform):
  401. lookup_name = 'month'
  402. class DayTransform(DateTransform):
  403. lookup_name = 'day'
  404. class WeekDayTransform(DateTransform):
  405. lookup_name = 'week_day'
  406. class HourTransform(DateTransform):
  407. lookup_name = 'hour'
  408. class MinuteTransform(DateTransform):
  409. lookup_name = 'minute'
  410. class SecondTransform(DateTransform):
  411. lookup_name = 'second'
  412. DateField.register_lookup(YearTransform)
  413. DateField.register_lookup(MonthTransform)
  414. DateField.register_lookup(DayTransform)
  415. DateField.register_lookup(WeekDayTransform)
  416. TimeField.register_lookup(HourTransform)
  417. TimeField.register_lookup(MinuteTransform)
  418. TimeField.register_lookup(SecondTransform)
  419. DateTimeField.register_lookup(DateTimeDateTransform)
  420. DateTimeField.register_lookup(YearTransform)
  421. DateTimeField.register_lookup(MonthTransform)
  422. DateTimeField.register_lookup(DayTransform)
  423. DateTimeField.register_lookup(WeekDayTransform)
  424. DateTimeField.register_lookup(HourTransform)
  425. DateTimeField.register_lookup(MinuteTransform)
  426. DateTimeField.register_lookup(SecondTransform)