You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

555 lines
19 KiB

  1. from copy import copy
  2. from django.conf import settings
  3. from django.db.models.expressions import Func, Value
  4. from django.db.models.fields import (
  5. DateField, DateTimeField, Field, IntegerField, TimeField,
  6. )
  7. from django.db.models.query_utils import RegisterLookupMixin
  8. from django.utils import timezone
  9. from django.utils.functional import cached_property
  10. from django.utils.six.moves import range
  11. class Lookup(object):
  12. lookup_name = None
  13. def __init__(self, lhs, rhs):
  14. self.lhs, self.rhs = lhs, rhs
  15. self.rhs = self.get_prep_lookup()
  16. if hasattr(self.lhs, 'get_bilateral_transforms'):
  17. bilateral_transforms = self.lhs.get_bilateral_transforms()
  18. else:
  19. bilateral_transforms = []
  20. if bilateral_transforms:
  21. # We should warn the user as soon as possible if he is trying to apply
  22. # a bilateral transformation on a nested QuerySet: that won't work.
  23. # We need to import QuerySet here so as to avoid circular
  24. from django.db.models.query import QuerySet
  25. if isinstance(rhs, QuerySet):
  26. raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
  27. self.bilateral_transforms = bilateral_transforms
  28. def apply_bilateral_transforms(self, value):
  29. for transform in self.bilateral_transforms:
  30. value = transform(value)
  31. return value
  32. def batch_process_rhs(self, compiler, connection, rhs=None):
  33. if rhs is None:
  34. rhs = self.rhs
  35. if self.bilateral_transforms:
  36. sqls, sqls_params = [], []
  37. for p in rhs:
  38. value = Value(p, output_field=self.lhs.output_field)
  39. value = self.apply_bilateral_transforms(value)
  40. value = value.resolve_expression(compiler.query)
  41. sql, sql_params = compiler.compile(value)
  42. sqls.append(sql)
  43. sqls_params.extend(sql_params)
  44. else:
  45. params = self.lhs.output_field.get_db_prep_lookup(
  46. self.lookup_name, rhs, connection, prepared=True)
  47. sqls, sqls_params = ['%s'] * len(params), params
  48. return sqls, sqls_params
  49. def get_prep_lookup(self):
  50. return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
  51. def get_db_prep_lookup(self, value, connection):
  52. return (
  53. '%s', self.lhs.output_field.get_db_prep_lookup(
  54. self.lookup_name, value, connection, prepared=True))
  55. def process_lhs(self, compiler, connection, lhs=None):
  56. lhs = lhs or self.lhs
  57. return compiler.compile(lhs)
  58. def process_rhs(self, compiler, connection):
  59. value = self.rhs
  60. if self.bilateral_transforms:
  61. if self.rhs_is_direct_value():
  62. # Do not call get_db_prep_lookup here as the value will be
  63. # transformed before being used for lookup
  64. value = Value(value, output_field=self.lhs.output_field)
  65. value = self.apply_bilateral_transforms(value)
  66. value = value.resolve_expression(compiler.query)
  67. # Due to historical reasons there are a couple of different
  68. # ways to produce sql here. get_compiler is likely a Query
  69. # instance, _as_sql QuerySet and as_sql just something with
  70. # as_sql. Finally the value can of course be just plain
  71. # Python value.
  72. if hasattr(value, 'get_compiler'):
  73. value = value.get_compiler(connection=connection)
  74. if hasattr(value, 'as_sql'):
  75. sql, params = compiler.compile(value)
  76. return '(' + sql + ')', params
  77. if hasattr(value, '_as_sql'):
  78. sql, params = value._as_sql(connection=connection)
  79. return '(' + sql + ')', params
  80. else:
  81. return self.get_db_prep_lookup(value, connection)
  82. def rhs_is_direct_value(self):
  83. return not(
  84. hasattr(self.rhs, 'as_sql') or
  85. hasattr(self.rhs, '_as_sql') or
  86. hasattr(self.rhs, 'get_compiler'))
  87. def relabeled_clone(self, relabels):
  88. new = copy(self)
  89. new.lhs = new.lhs.relabeled_clone(relabels)
  90. if hasattr(new.rhs, 'relabeled_clone'):
  91. new.rhs = new.rhs.relabeled_clone(relabels)
  92. return new
  93. def get_group_by_cols(self):
  94. cols = self.lhs.get_group_by_cols()
  95. if hasattr(self.rhs, 'get_group_by_cols'):
  96. cols.extend(self.rhs.get_group_by_cols())
  97. return cols
  98. def as_sql(self, compiler, connection):
  99. raise NotImplementedError
  100. @cached_property
  101. def contains_aggregate(self):
  102. return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
  103. class Transform(RegisterLookupMixin, Func):
  104. """
  105. RegisterLookupMixin() is first so that get_lookup() and get_transform()
  106. first examine self and then check output_field.
  107. """
  108. bilateral = False
  109. def __init__(self, expression, **extra):
  110. # Restrict Transform to allow only a single expression.
  111. super(Transform, self).__init__(expression, **extra)
  112. @property
  113. def lhs(self):
  114. return self.get_source_expressions()[0]
  115. def get_bilateral_transforms(self):
  116. if hasattr(self.lhs, 'get_bilateral_transforms'):
  117. bilateral_transforms = self.lhs.get_bilateral_transforms()
  118. else:
  119. bilateral_transforms = []
  120. if self.bilateral:
  121. bilateral_transforms.append(self.__class__)
  122. return bilateral_transforms
  123. class BuiltinLookup(Lookup):
  124. def process_lhs(self, compiler, connection, lhs=None):
  125. lhs_sql, params = super(BuiltinLookup, self).process_lhs(
  126. compiler, connection, lhs)
  127. field_internal_type = self.lhs.output_field.get_internal_type()
  128. db_type = self.lhs.output_field.db_type(connection=connection)
  129. lhs_sql = connection.ops.field_cast_sql(
  130. db_type, field_internal_type) % lhs_sql
  131. lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
  132. return lhs_sql, list(params)
  133. def as_sql(self, compiler, connection):
  134. lhs_sql, params = self.process_lhs(compiler, connection)
  135. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  136. params.extend(rhs_params)
  137. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  138. return '%s %s' % (lhs_sql, rhs_sql), params
  139. def get_rhs_op(self, connection, rhs):
  140. return connection.operators[self.lookup_name] % rhs
  141. class Exact(BuiltinLookup):
  142. lookup_name = 'exact'
  143. Field.register_lookup(Exact)
  144. class IExact(BuiltinLookup):
  145. lookup_name = 'iexact'
  146. def process_rhs(self, qn, connection):
  147. rhs, params = super(IExact, self).process_rhs(qn, connection)
  148. if params:
  149. params[0] = connection.ops.prep_for_iexact_query(params[0])
  150. return rhs, params
  151. Field.register_lookup(IExact)
  152. class GreaterThan(BuiltinLookup):
  153. lookup_name = 'gt'
  154. Field.register_lookup(GreaterThan)
  155. class GreaterThanOrEqual(BuiltinLookup):
  156. lookup_name = 'gte'
  157. Field.register_lookup(GreaterThanOrEqual)
  158. class LessThan(BuiltinLookup):
  159. lookup_name = 'lt'
  160. Field.register_lookup(LessThan)
  161. class LessThanOrEqual(BuiltinLookup):
  162. lookup_name = 'lte'
  163. Field.register_lookup(LessThanOrEqual)
  164. class In(BuiltinLookup):
  165. lookup_name = 'in'
  166. def process_rhs(self, compiler, connection):
  167. if self.rhs_is_direct_value():
  168. # rhs should be an iterable, we use batch_process_rhs
  169. # to prepare/transform those values
  170. rhs = list(self.rhs)
  171. if not rhs:
  172. from django.db.models.sql.datastructures import EmptyResultSet
  173. raise EmptyResultSet
  174. sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
  175. placeholder = '(' + ', '.join(sqls) + ')'
  176. return (placeholder, sqls_params)
  177. else:
  178. return super(In, self).process_rhs(compiler, connection)
  179. def get_rhs_op(self, connection, rhs):
  180. return 'IN %s' % rhs
  181. def as_sql(self, compiler, connection):
  182. max_in_list_size = connection.ops.max_in_list_size()
  183. if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
  184. return self.split_parameter_list_as_sql(compiler, connection)
  185. return super(In, self).as_sql(compiler, connection)
  186. def split_parameter_list_as_sql(self, compiler, connection):
  187. # This is a special case for databases which limit the number of
  188. # elements which can appear in an 'IN' clause.
  189. max_in_list_size = connection.ops.max_in_list_size()
  190. lhs, lhs_params = self.process_lhs(compiler, connection)
  191. rhs, rhs_params = self.batch_process_rhs(compiler, connection)
  192. in_clause_elements = ['(']
  193. params = []
  194. for offset in range(0, len(rhs_params), max_in_list_size):
  195. if offset > 0:
  196. in_clause_elements.append(' OR ')
  197. in_clause_elements.append('%s IN (' % lhs)
  198. params.extend(lhs_params)
  199. sqls = rhs[offset: offset + max_in_list_size]
  200. sqls_params = rhs_params[offset: offset + max_in_list_size]
  201. param_group = ', '.join(sqls)
  202. in_clause_elements.append(param_group)
  203. in_clause_elements.append(')')
  204. params.extend(sqls_params)
  205. in_clause_elements.append(')')
  206. return ''.join(in_clause_elements), params
  207. Field.register_lookup(In)
  208. class PatternLookup(BuiltinLookup):
  209. def get_rhs_op(self, connection, rhs):
  210. # Assume we are in startswith. We need to produce SQL like:
  211. # col LIKE %s, ['thevalue%']
  212. # For python values we can (and should) do that directly in Python,
  213. # but if the value is for example reference to other column, then
  214. # we need to add the % pattern match to the lookup by something like
  215. # col LIKE othercol || '%%'
  216. # So, for Python values we don't need any special pattern, but for
  217. # SQL reference values or SQL transformations we need the correct
  218. # pattern added.
  219. if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
  220. or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
  221. pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
  222. return pattern.format(rhs)
  223. else:
  224. return super(PatternLookup, self).get_rhs_op(connection, rhs)
  225. class Contains(PatternLookup):
  226. lookup_name = 'contains'
  227. def process_rhs(self, qn, connection):
  228. rhs, params = super(Contains, self).process_rhs(qn, connection)
  229. if params and not self.bilateral_transforms:
  230. params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
  231. return rhs, params
  232. Field.register_lookup(Contains)
  233. class IContains(Contains):
  234. lookup_name = 'icontains'
  235. Field.register_lookup(IContains)
  236. class StartsWith(PatternLookup):
  237. lookup_name = 'startswith'
  238. def process_rhs(self, qn, connection):
  239. rhs, params = super(StartsWith, self).process_rhs(qn, connection)
  240. if params and not self.bilateral_transforms:
  241. params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
  242. return rhs, params
  243. Field.register_lookup(StartsWith)
  244. class IStartsWith(PatternLookup):
  245. lookup_name = 'istartswith'
  246. def process_rhs(self, qn, connection):
  247. rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
  248. if params and not self.bilateral_transforms:
  249. params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
  250. return rhs, params
  251. Field.register_lookup(IStartsWith)
  252. class EndsWith(PatternLookup):
  253. lookup_name = 'endswith'
  254. def process_rhs(self, qn, connection):
  255. rhs, params = super(EndsWith, self).process_rhs(qn, connection)
  256. if params and not self.bilateral_transforms:
  257. params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
  258. return rhs, params
  259. Field.register_lookup(EndsWith)
  260. class IEndsWith(PatternLookup):
  261. lookup_name = 'iendswith'
  262. def process_rhs(self, qn, connection):
  263. rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
  264. if params and not self.bilateral_transforms:
  265. params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
  266. return rhs, params
  267. Field.register_lookup(IEndsWith)
  268. class Between(BuiltinLookup):
  269. def get_rhs_op(self, connection, rhs):
  270. return "BETWEEN %s AND %s" % (rhs, rhs)
  271. class Range(BuiltinLookup):
  272. lookup_name = 'range'
  273. def get_rhs_op(self, connection, rhs):
  274. return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
  275. def process_rhs(self, compiler, connection):
  276. if self.rhs_is_direct_value():
  277. # rhs should be an iterable of 2 values, we use batch_process_rhs
  278. # to prepare/transform those values
  279. return self.batch_process_rhs(compiler, connection)
  280. else:
  281. return super(Range, self).process_rhs(compiler, connection)
  282. Field.register_lookup(Range)
  283. class IsNull(BuiltinLookup):
  284. lookup_name = 'isnull'
  285. def as_sql(self, compiler, connection):
  286. sql, params = compiler.compile(self.lhs)
  287. if self.rhs:
  288. return "%s IS NULL" % sql, params
  289. else:
  290. return "%s IS NOT NULL" % sql, params
  291. Field.register_lookup(IsNull)
  292. class Search(BuiltinLookup):
  293. lookup_name = 'search'
  294. def as_sql(self, compiler, connection):
  295. lhs, lhs_params = self.process_lhs(compiler, connection)
  296. rhs, rhs_params = self.process_rhs(compiler, connection)
  297. sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
  298. return sql_template, lhs_params + rhs_params
  299. Field.register_lookup(Search)
  300. class Regex(BuiltinLookup):
  301. lookup_name = 'regex'
  302. def as_sql(self, compiler, connection):
  303. if self.lookup_name in connection.operators:
  304. return super(Regex, self).as_sql(compiler, connection)
  305. else:
  306. lhs, lhs_params = self.process_lhs(compiler, connection)
  307. rhs, rhs_params = self.process_rhs(compiler, connection)
  308. sql_template = connection.ops.regex_lookup(self.lookup_name)
  309. return sql_template % (lhs, rhs), lhs_params + rhs_params
  310. Field.register_lookup(Regex)
  311. class IRegex(Regex):
  312. lookup_name = 'iregex'
  313. Field.register_lookup(IRegex)
  314. class DateTimeDateTransform(Transform):
  315. lookup_name = 'date'
  316. @cached_property
  317. def output_field(self):
  318. return DateField()
  319. def as_sql(self, compiler, connection):
  320. lhs, lhs_params = compiler.compile(self.lhs)
  321. tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
  322. sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
  323. lhs_params.extend(tz_params)
  324. return sql, lhs_params
  325. class DateTransform(Transform):
  326. def as_sql(self, compiler, connection):
  327. sql, params = compiler.compile(self.lhs)
  328. lhs_output_field = self.lhs.output_field
  329. if isinstance(lhs_output_field, DateTimeField):
  330. tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
  331. sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
  332. params.extend(tz_params)
  333. elif isinstance(lhs_output_field, DateField):
  334. sql = connection.ops.date_extract_sql(self.lookup_name, sql)
  335. elif isinstance(lhs_output_field, TimeField):
  336. sql = connection.ops.time_extract_sql(self.lookup_name, sql)
  337. else:
  338. raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
  339. return sql, params
  340. @cached_property
  341. def output_field(self):
  342. return IntegerField()
  343. class YearTransform(DateTransform):
  344. lookup_name = 'year'
  345. class YearLookup(Lookup):
  346. def year_lookup_bounds(self, connection, year):
  347. output_field = self.lhs.lhs.output_field
  348. if isinstance(output_field, DateTimeField):
  349. bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
  350. else:
  351. bounds = connection.ops.year_lookup_bounds_for_date_field(year)
  352. return bounds
  353. @YearTransform.register_lookup
  354. class YearExact(YearLookup):
  355. lookup_name = 'exact'
  356. def as_sql(self, compiler, connection):
  357. # We will need to skip the extract part and instead go
  358. # directly with the originating field, that is self.lhs.lhs.
  359. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  360. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  361. bounds = self.year_lookup_bounds(connection, rhs_params[0])
  362. params.extend(bounds)
  363. return '%s BETWEEN %%s AND %%s' % lhs_sql, params
  364. class YearComparisonLookup(YearLookup):
  365. def as_sql(self, compiler, connection):
  366. # We will need to skip the extract part and instead go
  367. # directly with the originating field, that is self.lhs.lhs.
  368. lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
  369. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  370. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  371. start, finish = self.year_lookup_bounds(connection, rhs_params[0])
  372. params.append(self.get_bound(start, finish))
  373. return '%s %s' % (lhs_sql, rhs_sql), params
  374. def get_rhs_op(self, connection, rhs):
  375. return connection.operators[self.lookup_name] % rhs
  376. def get_bound(self):
  377. raise NotImplementedError(
  378. 'subclasses of YearComparisonLookup must provide a get_bound() method'
  379. )
  380. @YearTransform.register_lookup
  381. class YearGt(YearComparisonLookup):
  382. lookup_name = 'gt'
  383. def get_bound(self, start, finish):
  384. return finish
  385. @YearTransform.register_lookup
  386. class YearGte(YearComparisonLookup):
  387. lookup_name = 'gte'
  388. def get_bound(self, start, finish):
  389. return start
  390. @YearTransform.register_lookup
  391. class YearLt(YearComparisonLookup):
  392. lookup_name = 'lt'
  393. def get_bound(self, start, finish):
  394. return start
  395. @YearTransform.register_lookup
  396. class YearLte(YearComparisonLookup):
  397. lookup_name = 'lte'
  398. def get_bound(self, start, finish):
  399. return finish
  400. class MonthTransform(DateTransform):
  401. lookup_name = 'month'
  402. class DayTransform(DateTransform):
  403. lookup_name = 'day'
  404. class WeekDayTransform(DateTransform):
  405. lookup_name = 'week_day'
  406. class HourTransform(DateTransform):
  407. lookup_name = 'hour'
  408. class MinuteTransform(DateTransform):
  409. lookup_name = 'minute'
  410. class SecondTransform(DateTransform):
  411. lookup_name = 'second'
  412. DateField.register_lookup(YearTransform)
  413. DateField.register_lookup(MonthTransform)
  414. DateField.register_lookup(DayTransform)
  415. DateField.register_lookup(WeekDayTransform)
  416. TimeField.register_lookup(HourTransform)
  417. TimeField.register_lookup(MinuteTransform)
  418. TimeField.register_lookup(SecondTransform)
  419. DateTimeField.register_lookup(DateTimeDateTransform)
  420. DateTimeField.register_lookup(YearTransform)
  421. DateTimeField.register_lookup(MonthTransform)
  422. DateTimeField.register_lookup(DayTransform)
  423. DateTimeField.register_lookup(WeekDayTransform)
  424. DateTimeField.register_lookup(HourTransform)
  425. DateTimeField.register_lookup(MinuteTransform)
  426. DateTimeField.register_lookup(SecondTransform)