You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

jsonb.py 2.9 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import json
  2. from psycopg2.extras import Json
  3. from django.contrib.postgres import forms, lookups
  4. from django.core import exceptions
  5. from django.db.models import Field, Transform
  6. from django.utils.translation import ugettext_lazy as _
  7. __all__ = ['JSONField']
  8. class JSONField(Field):
  9. empty_strings_allowed = False
  10. description = _('A JSON object')
  11. default_error_messages = {
  12. 'invalid': _("Value must be valid JSON."),
  13. }
  14. def db_type(self, connection):
  15. return 'jsonb'
  16. def get_transform(self, name):
  17. transform = super(JSONField, self).get_transform(name)
  18. if transform:
  19. return transform
  20. return KeyTransformFactory(name)
  21. def get_prep_value(self, value):
  22. if value is not None:
  23. return Json(value)
  24. return value
  25. def get_prep_lookup(self, lookup_type, value):
  26. if lookup_type in ('has_key', 'has_keys', 'has_any_keys'):
  27. return value
  28. if isinstance(value, (dict, list)):
  29. return Json(value)
  30. return super(JSONField, self).get_prep_lookup(lookup_type, value)
  31. def validate(self, value, model_instance):
  32. super(JSONField, self).validate(value, model_instance)
  33. try:
  34. json.dumps(value)
  35. except TypeError:
  36. raise exceptions.ValidationError(
  37. self.error_messages['invalid'],
  38. code='invalid',
  39. params={'value': value},
  40. )
  41. def value_to_string(self, obj):
  42. value = self.value_from_object(obj)
  43. return value
  44. def formfield(self, **kwargs):
  45. defaults = {'form_class': forms.JSONField}
  46. defaults.update(kwargs)
  47. return super(JSONField, self).formfield(**defaults)
  48. JSONField.register_lookup(lookups.DataContains)
  49. JSONField.register_lookup(lookups.ContainedBy)
  50. JSONField.register_lookup(lookups.HasKey)
  51. JSONField.register_lookup(lookups.HasKeys)
  52. JSONField.register_lookup(lookups.HasAnyKeys)
  53. class KeyTransform(Transform):
  54. def __init__(self, key_name, *args, **kwargs):
  55. super(KeyTransform, self).__init__(*args, **kwargs)
  56. self.key_name = key_name
  57. def as_sql(self, compiler, connection):
  58. key_transforms = [self.key_name]
  59. previous = self.lhs
  60. while isinstance(previous, KeyTransform):
  61. key_transforms.insert(0, previous.key_name)
  62. previous = previous.lhs
  63. lhs, params = compiler.compile(previous)
  64. if len(key_transforms) > 1:
  65. return "{} #> %s".format(lhs), [key_transforms] + params
  66. try:
  67. int(self.key_name)
  68. except ValueError:
  69. lookup = "'%s'" % self.key_name
  70. else:
  71. lookup = "%s" % self.key_name
  72. return "%s -> %s" % (lhs, lookup), params
  73. class KeyTransformFactory(object):
  74. def __init__(self, key_name):
  75. self.key_name = key_name
  76. def __call__(self, *args, **kwargs):
  77. return KeyTransform(self.key_name, *args, **kwargs)