You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

204 lines
5.6 KiB

  1. import json
  2. from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
  3. from django.contrib.postgres import forms, lookups
  4. from django.db import models
  5. from django.utils import six
  6. from .utils import AttributeSetter
  7. __all__ = [
  8. 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
  9. 'FloatRangeField', 'DateTimeRangeField', 'DateRangeField',
  10. ]
  11. class RangeField(models.Field):
  12. empty_strings_allowed = False
  13. def get_prep_value(self, value):
  14. if value is None:
  15. return None
  16. elif isinstance(value, Range):
  17. return value
  18. elif isinstance(value, (list, tuple)):
  19. return self.range_type(value[0], value[1])
  20. return value
  21. def to_python(self, value):
  22. if isinstance(value, six.string_types):
  23. # Assume we're deserializing
  24. vals = json.loads(value)
  25. for end in ('lower', 'upper'):
  26. if end in vals:
  27. vals[end] = self.base_field.to_python(vals[end])
  28. value = self.range_type(**vals)
  29. elif isinstance(value, (list, tuple)):
  30. value = self.range_type(value[0], value[1])
  31. return value
  32. def set_attributes_from_name(self, name):
  33. super(RangeField, self).set_attributes_from_name(name)
  34. self.base_field.set_attributes_from_name(name)
  35. def value_to_string(self, obj):
  36. value = self.value_from_object(obj)
  37. if value is None:
  38. return None
  39. if value.isempty:
  40. return json.dumps({"empty": True})
  41. base_field = self.base_field
  42. result = {"bounds": value._bounds}
  43. for end in ('lower', 'upper'):
  44. val = getattr(value, end)
  45. if val is None:
  46. result[end] = None
  47. else:
  48. obj = AttributeSetter(base_field.attname, val)
  49. result[end] = base_field.value_to_string(obj)
  50. return json.dumps(result)
  51. def formfield(self, **kwargs):
  52. kwargs.setdefault('form_class', self.form_field)
  53. return super(RangeField, self).formfield(**kwargs)
  54. class IntegerRangeField(RangeField):
  55. base_field = models.IntegerField()
  56. range_type = NumericRange
  57. form_field = forms.IntegerRangeField
  58. def db_type(self, connection):
  59. return 'int4range'
  60. class BigIntegerRangeField(RangeField):
  61. base_field = models.BigIntegerField()
  62. range_type = NumericRange
  63. form_field = forms.IntegerRangeField
  64. def db_type(self, connection):
  65. return 'int8range'
  66. class FloatRangeField(RangeField):
  67. base_field = models.FloatField()
  68. range_type = NumericRange
  69. form_field = forms.FloatRangeField
  70. def db_type(self, connection):
  71. return 'numrange'
  72. class DateTimeRangeField(RangeField):
  73. base_field = models.DateTimeField()
  74. range_type = DateTimeTZRange
  75. form_field = forms.DateTimeRangeField
  76. def db_type(self, connection):
  77. return 'tstzrange'
  78. class DateRangeField(RangeField):
  79. base_field = models.DateField()
  80. range_type = DateRange
  81. form_field = forms.DateRangeField
  82. def db_type(self, connection):
  83. return 'daterange'
  84. RangeField.register_lookup(lookups.DataContains)
  85. RangeField.register_lookup(lookups.ContainedBy)
  86. RangeField.register_lookup(lookups.Overlap)
  87. class RangeContainedBy(models.Lookup):
  88. lookup_name = 'contained_by'
  89. type_mapping = {
  90. 'integer': 'int4range',
  91. 'bigint': 'int8range',
  92. 'double precision': 'numrange',
  93. 'date': 'daterange',
  94. 'timestamp with time zone': 'tstzrange',
  95. }
  96. def as_sql(self, qn, connection):
  97. field = self.lhs.output_field
  98. if isinstance(field, models.FloatField):
  99. sql = '%s::numeric <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
  100. else:
  101. sql = '%s <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
  102. lhs, lhs_params = self.process_lhs(qn, connection)
  103. rhs, rhs_params = self.process_rhs(qn, connection)
  104. params = lhs_params + rhs_params
  105. return sql % (lhs, rhs), params
  106. def get_prep_lookup(self):
  107. return RangeField().get_prep_lookup(self.lookup_name, self.rhs)
  108. models.DateField.register_lookup(RangeContainedBy)
  109. models.DateTimeField.register_lookup(RangeContainedBy)
  110. models.IntegerField.register_lookup(RangeContainedBy)
  111. models.BigIntegerField.register_lookup(RangeContainedBy)
  112. models.FloatField.register_lookup(RangeContainedBy)
  113. @RangeField.register_lookup
  114. class FullyLessThan(lookups.PostgresSimpleLookup):
  115. lookup_name = 'fully_lt'
  116. operator = '<<'
  117. @RangeField.register_lookup
  118. class FullGreaterThan(lookups.PostgresSimpleLookup):
  119. lookup_name = 'fully_gt'
  120. operator = '>>'
  121. @RangeField.register_lookup
  122. class NotLessThan(lookups.PostgresSimpleLookup):
  123. lookup_name = 'not_lt'
  124. operator = '&>'
  125. @RangeField.register_lookup
  126. class NotGreaterThan(lookups.PostgresSimpleLookup):
  127. lookup_name = 'not_gt'
  128. operator = '&<'
  129. @RangeField.register_lookup
  130. class AdjacentToLookup(lookups.PostgresSimpleLookup):
  131. lookup_name = 'adjacent_to'
  132. operator = '-|-'
  133. @RangeField.register_lookup
  134. class RangeStartsWith(models.Transform):
  135. lookup_name = 'startswith'
  136. function = 'lower'
  137. @property
  138. def output_field(self):
  139. return self.lhs.output_field.base_field
  140. @RangeField.register_lookup
  141. class RangeEndsWith(models.Transform):
  142. lookup_name = 'endswith'
  143. function = 'upper'
  144. @property
  145. def output_field(self):
  146. return self.lhs.output_field.base_field
  147. @RangeField.register_lookup
  148. class IsEmpty(models.Transform):
  149. lookup_name = 'isempty'
  150. function = 'isempty'
  151. output_field = models.BooleanField()