You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

array.py 9.0 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, IntegerField, Transform
  7. from django.db.models.lookups import Exact
  8. from django.utils import six
  9. from django.utils.translation import string_concat, ugettext_lazy as _
  10. from .utils import AttributeSetter
# Public API of this module: only the model field is exported by
# ``from ... import *``; lookups and transforms are registered on it instead.
__all__ = ['ArrayField']
  12. class ArrayField(Field):
  13. empty_strings_allowed = False
  14. default_error_messages = {
  15. 'item_invalid': _('Item %(nth)s in the array did not validate: '),
  16. 'nested_array_mismatch': _('Nested arrays must have the same length.'),
  17. }
  18. def __init__(self, base_field, size=None, **kwargs):
  19. self.base_field = base_field
  20. self.size = size
  21. if self.size:
  22. self.default_validators = self.default_validators[:]
  23. self.default_validators.append(ArrayMaxLengthValidator(self.size))
  24. super(ArrayField, self).__init__(**kwargs)
  25. @property
  26. def model(self):
  27. try:
  28. return self.__dict__['model']
  29. except KeyError:
  30. raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
  31. @model.setter
  32. def model(self, model):
  33. self.__dict__['model'] = model
  34. self.base_field.model = model
  35. def check(self, **kwargs):
  36. errors = super(ArrayField, self).check(**kwargs)
  37. if self.base_field.remote_field:
  38. errors.append(
  39. checks.Error(
  40. 'Base field for array cannot be a related field.',
  41. hint=None,
  42. obj=self,
  43. id='postgres.E002'
  44. )
  45. )
  46. else:
  47. # Remove the field name checks as they are not needed here.
  48. base_errors = self.base_field.check()
  49. if base_errors:
  50. messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
  51. errors.append(
  52. checks.Error(
  53. 'Base field for array has errors:\n %s' % messages,
  54. hint=None,
  55. obj=self,
  56. id='postgres.E001'
  57. )
  58. )
  59. return errors
  60. def set_attributes_from_name(self, name):
  61. super(ArrayField, self).set_attributes_from_name(name)
  62. self.base_field.set_attributes_from_name(name)
  63. @property
  64. def description(self):
  65. return 'Array of %s' % self.base_field.description
  66. def db_type(self, connection):
  67. size = self.size or ''
  68. return '%s[%s]' % (self.base_field.db_type(connection), size)
  69. def get_db_prep_value(self, value, connection, prepared=False):
  70. if isinstance(value, list) or isinstance(value, tuple):
  71. return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value]
  72. return value
  73. def deconstruct(self):
  74. name, path, args, kwargs = super(ArrayField, self).deconstruct()
  75. if path == 'django.contrib.postgres.fields.array.ArrayField':
  76. path = 'django.contrib.postgres.fields.ArrayField'
  77. kwargs.update({
  78. 'base_field': self.base_field,
  79. 'size': self.size,
  80. })
  81. return name, path, args, kwargs
  82. def to_python(self, value):
  83. if isinstance(value, six.string_types):
  84. # Assume we're deserializing
  85. vals = json.loads(value)
  86. value = [self.base_field.to_python(val) for val in vals]
  87. return value
  88. def value_to_string(self, obj):
  89. values = []
  90. vals = self.value_from_object(obj)
  91. base_field = self.base_field
  92. for val in vals:
  93. if val is None:
  94. values.append(None)
  95. else:
  96. obj = AttributeSetter(base_field.attname, val)
  97. values.append(base_field.value_to_string(obj))
  98. return json.dumps(values)
  99. def get_transform(self, name):
  100. transform = super(ArrayField, self).get_transform(name)
  101. if transform:
  102. return transform
  103. try:
  104. index = int(name)
  105. except ValueError:
  106. pass
  107. else:
  108. index += 1 # postgres uses 1-indexing
  109. return IndexTransformFactory(index, self.base_field)
  110. try:
  111. start, end = name.split('_')
  112. start = int(start) + 1
  113. end = int(end) # don't add one here because postgres slices are weird
  114. except ValueError:
  115. pass
  116. else:
  117. return SliceTransformFactory(start, end)
  118. def validate(self, value, model_instance):
  119. super(ArrayField, self).validate(value, model_instance)
  120. for i, part in enumerate(value):
  121. try:
  122. self.base_field.validate(part, model_instance)
  123. except exceptions.ValidationError as e:
  124. raise exceptions.ValidationError(
  125. string_concat(self.error_messages['item_invalid'], e.message),
  126. code='item_invalid',
  127. params={'nth': i},
  128. )
  129. if isinstance(self.base_field, ArrayField):
  130. if len({len(i) for i in value}) > 1:
  131. raise exceptions.ValidationError(
  132. self.error_messages['nested_array_mismatch'],
  133. code='nested_array_mismatch',
  134. )
  135. def run_validators(self, value):
  136. super(ArrayField, self).run_validators(value)
  137. for i, part in enumerate(value):
  138. try:
  139. self.base_field.run_validators(part)
  140. except exceptions.ValidationError as e:
  141. raise exceptions.ValidationError(
  142. string_concat(self.error_messages['item_invalid'], ' '.join(e.messages)),
  143. code='item_invalid',
  144. params={'nth': i},
  145. )
  146. def formfield(self, **kwargs):
  147. defaults = {
  148. 'form_class': SimpleArrayField,
  149. 'base_field': self.base_field.formfield(),
  150. 'max_length': self.size,
  151. }
  152. defaults.update(kwargs)
  153. return super(ArrayField, self).formfield(**defaults)
  154. @ArrayField.register_lookup
  155. class ArrayContains(lookups.DataContains):
  156. def as_sql(self, qn, connection):
  157. sql, params = super(ArrayContains, self).as_sql(qn, connection)
  158. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  159. return sql, params
  160. @ArrayField.register_lookup
  161. class ArrayContainedBy(lookups.ContainedBy):
  162. def as_sql(self, qn, connection):
  163. sql, params = super(ArrayContainedBy, self).as_sql(qn, connection)
  164. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  165. return sql, params
  166. @ArrayField.register_lookup
  167. class ArrayExact(Exact):
  168. def as_sql(self, qn, connection):
  169. sql, params = super(ArrayExact, self).as_sql(qn, connection)
  170. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  171. return sql, params
  172. @ArrayField.register_lookup
  173. class ArrayOverlap(lookups.Overlap):
  174. def as_sql(self, qn, connection):
  175. sql, params = super(ArrayOverlap, self).as_sql(qn, connection)
  176. sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
  177. return sql, params
  178. @ArrayField.register_lookup
  179. class ArrayLenTransform(Transform):
  180. lookup_name = 'len'
  181. output_field = IntegerField()
  182. def as_sql(self, compiler, connection):
  183. lhs, params = compiler.compile(self.lhs)
  184. # Distinguish NULL and empty arrays
  185. return (
  186. 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
  187. 'coalesce(array_length(%(lhs)s, 1), 0) END'
  188. ) % {'lhs': lhs}, params
  189. class IndexTransform(Transform):
  190. def __init__(self, index, base_field, *args, **kwargs):
  191. super(IndexTransform, self).__init__(*args, **kwargs)
  192. self.index = index
  193. self.base_field = base_field
  194. def as_sql(self, compiler, connection):
  195. lhs, params = compiler.compile(self.lhs)
  196. return '%s[%s]' % (lhs, self.index), params
  197. @property
  198. def output_field(self):
  199. return self.base_field
  200. class IndexTransformFactory(object):
  201. def __init__(self, index, base_field):
  202. self.index = index
  203. self.base_field = base_field
  204. def __call__(self, *args, **kwargs):
  205. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  206. class SliceTransform(Transform):
  207. def __init__(self, start, end, *args, **kwargs):
  208. super(SliceTransform, self).__init__(*args, **kwargs)
  209. self.start = start
  210. self.end = end
  211. def as_sql(self, compiler, connection):
  212. lhs, params = compiler.compile(self.lhs)
  213. return '%s[%s:%s]' % (lhs, self.start, self.end), params
  214. class SliceTransformFactory(object):
  215. def __init__(self, start, end):
  216. self.start = start
  217. self.end = end
  218. def __call__(self, *args, **kwargs):
  219. return SliceTransform(self.start, self.end, *args, **kwargs)