You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

299 lines
11 KiB

  1. from __future__ import unicode_literals
  2. from django.db.models.fields import NOT_PROVIDED
  3. from django.utils import six
  4. from django.utils.functional import cached_property
  5. from .base import Operation
  6. class AddField(Operation):
  7. """
  8. Adds a field to a model.
  9. """
  10. def __init__(self, model_name, name, field, preserve_default=True):
  11. self.model_name = model_name
  12. self.name = name
  13. self.field = field
  14. self.preserve_default = preserve_default
  15. @cached_property
  16. def name_lower(self):
  17. return self.name.lower()
  18. @cached_property
  19. def model_name_lower(self):
  20. return self.model_name.lower()
  21. def deconstruct(self):
  22. kwargs = {
  23. 'model_name': self.model_name,
  24. 'name': self.name,
  25. 'field': self.field,
  26. }
  27. if self.preserve_default is not True:
  28. kwargs['preserve_default'] = self.preserve_default
  29. return (
  30. self.__class__.__name__,
  31. [],
  32. kwargs
  33. )
  34. def state_forwards(self, app_label, state):
  35. # If preserve default is off, don't use the default for future state
  36. if not self.preserve_default:
  37. field = self.field.clone()
  38. field.default = NOT_PROVIDED
  39. else:
  40. field = self.field
  41. state.models[app_label, self.model_name_lower].fields.append((self.name, field))
  42. state.reload_model(app_label, self.model_name_lower)
  43. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  44. to_model = to_state.apps.get_model(app_label, self.model_name)
  45. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  46. from_model = from_state.apps.get_model(app_label, self.model_name)
  47. field = to_model._meta.get_field(self.name)
  48. if not self.preserve_default:
  49. field.default = self.field.default
  50. schema_editor.add_field(
  51. from_model,
  52. field,
  53. )
  54. if not self.preserve_default:
  55. field.default = NOT_PROVIDED
  56. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  57. from_model = from_state.apps.get_model(app_label, self.model_name)
  58. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  59. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  60. def describe(self):
  61. return "Add field %s to %s" % (self.name, self.model_name)
  62. def references_model(self, name, app_label=None):
  63. return name.lower() == self.model_name_lower
  64. def references_field(self, model_name, name, app_label=None):
  65. return self.references_model(model_name) and name.lower() == self.name_lower
  66. class RemoveField(Operation):
  67. """
  68. Removes a field from a model.
  69. """
  70. def __init__(self, model_name, name):
  71. self.model_name = model_name
  72. self.name = name
  73. @cached_property
  74. def name_lower(self):
  75. return self.name.lower()
  76. @cached_property
  77. def model_name_lower(self):
  78. return self.model_name.lower()
  79. def deconstruct(self):
  80. kwargs = {
  81. 'model_name': self.model_name,
  82. 'name': self.name,
  83. }
  84. return (
  85. self.__class__.__name__,
  86. [],
  87. kwargs
  88. )
  89. def state_forwards(self, app_label, state):
  90. new_fields = []
  91. for name, instance in state.models[app_label, self.model_name_lower].fields:
  92. if name != self.name:
  93. new_fields.append((name, instance))
  94. state.models[app_label, self.model_name_lower].fields = new_fields
  95. state.reload_model(app_label, self.model_name_lower)
  96. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  97. from_model = from_state.apps.get_model(app_label, self.model_name)
  98. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  99. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  100. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  101. to_model = to_state.apps.get_model(app_label, self.model_name)
  102. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  103. from_model = from_state.apps.get_model(app_label, self.model_name)
  104. schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
  105. def describe(self):
  106. return "Remove field %s from %s" % (self.name, self.model_name)
  107. def references_model(self, name, app_label=None):
  108. return name.lower() == self.model_name_lower
  109. def references_field(self, model_name, name, app_label=None):
  110. return self.references_model(model_name) and name.lower() == self.name_lower
  111. class AlterField(Operation):
  112. """
  113. Alters a field's database column (e.g. null, max_length) to the provided new field
  114. """
  115. def __init__(self, model_name, name, field, preserve_default=True):
  116. self.model_name = model_name
  117. self.name = name
  118. self.field = field
  119. self.preserve_default = preserve_default
  120. @cached_property
  121. def name_lower(self):
  122. return self.name.lower()
  123. @cached_property
  124. def model_name_lower(self):
  125. return self.model_name.lower()
  126. def deconstruct(self):
  127. kwargs = {
  128. 'model_name': self.model_name,
  129. 'name': self.name,
  130. 'field': self.field,
  131. }
  132. if self.preserve_default is not True:
  133. kwargs['preserve_default'] = self.preserve_default
  134. return (
  135. self.__class__.__name__,
  136. [],
  137. kwargs
  138. )
  139. def state_forwards(self, app_label, state):
  140. if not self.preserve_default:
  141. field = self.field.clone()
  142. field.default = NOT_PROVIDED
  143. else:
  144. field = self.field
  145. state.models[app_label, self.model_name_lower].fields = [
  146. (n, field if n == self.name else f)
  147. for n, f in
  148. state.models[app_label, self.model_name_lower].fields
  149. ]
  150. state.reload_model(app_label, self.model_name_lower)
  151. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  152. to_model = to_state.apps.get_model(app_label, self.model_name)
  153. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  154. from_model = from_state.apps.get_model(app_label, self.model_name)
  155. from_field = from_model._meta.get_field(self.name)
  156. to_field = to_model._meta.get_field(self.name)
  157. # If the field is a relatedfield with an unresolved rel.to, just
  158. # set it equal to the other field side. Bandaid fix for AlterField
  159. # migrations that are part of a RenameModel change.
  160. if from_field.remote_field and from_field.remote_field.model:
  161. if isinstance(from_field.remote_field.model, six.string_types):
  162. from_field.remote_field.model = to_field.remote_field.model
  163. elif to_field.remote_field and isinstance(to_field.remote_field.model, six.string_types):
  164. to_field.remote_field.model = from_field.remote_field.model
  165. if not self.preserve_default:
  166. to_field.default = self.field.default
  167. schema_editor.alter_field(from_model, from_field, to_field)
  168. if not self.preserve_default:
  169. to_field.default = NOT_PROVIDED
  170. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  171. self.database_forwards(app_label, schema_editor, from_state, to_state)
  172. def describe(self):
  173. return "Alter field %s on %s" % (self.name, self.model_name)
  174. def references_model(self, name, app_label=None):
  175. return name.lower() == self.model_name_lower
  176. def references_field(self, model_name, name, app_label=None):
  177. return self.references_model(model_name) and name.lower() == self.name_lower
  178. class RenameField(Operation):
  179. """
  180. Renames a field on the model. Might affect db_column too.
  181. """
  182. def __init__(self, model_name, old_name, new_name):
  183. self.model_name = model_name
  184. self.old_name = old_name
  185. self.new_name = new_name
  186. @cached_property
  187. def old_name_lower(self):
  188. return self.old_name.lower()
  189. @cached_property
  190. def new_name_lower(self):
  191. return self.new_name.lower()
  192. @cached_property
  193. def model_name_lower(self):
  194. return self.model_name.lower()
  195. def deconstruct(self):
  196. kwargs = {
  197. 'model_name': self.model_name,
  198. 'old_name': self.old_name,
  199. 'new_name': self.new_name,
  200. }
  201. return (
  202. self.__class__.__name__,
  203. [],
  204. kwargs
  205. )
  206. def state_forwards(self, app_label, state):
  207. # Rename the field
  208. state.models[app_label, self.model_name_lower].fields = [
  209. (self.new_name if n == self.old_name else n, f)
  210. for n, f in state.models[app_label, self.model_name_lower].fields
  211. ]
  212. # Fix index/unique_together to refer to the new field
  213. options = state.models[app_label, self.model_name_lower].options
  214. for option in ('index_together', 'unique_together'):
  215. if option in options:
  216. options[option] = [
  217. [self.new_name if n == self.old_name else n for n in together]
  218. for together in options[option]
  219. ]
  220. state.reload_model(app_label, self.model_name_lower)
  221. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  222. to_model = to_state.apps.get_model(app_label, self.model_name)
  223. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  224. from_model = from_state.apps.get_model(app_label, self.model_name)
  225. schema_editor.alter_field(
  226. from_model,
  227. from_model._meta.get_field(self.old_name),
  228. to_model._meta.get_field(self.new_name),
  229. )
  230. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  231. to_model = to_state.apps.get_model(app_label, self.model_name)
  232. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  233. from_model = from_state.apps.get_model(app_label, self.model_name)
  234. schema_editor.alter_field(
  235. from_model,
  236. from_model._meta.get_field(self.new_name),
  237. to_model._meta.get_field(self.old_name),
  238. )
  239. def describe(self):
  240. return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
  241. def references_model(self, name, app_label=None):
  242. return name.lower() == self.model_name_lower
  243. def references_field(self, model_name, name, app_label=None):
  244. return self.references_model(model_name) and (
  245. name.lower() == self.old_name_lower or
  246. name.lower() == self.new_name_lower
  247. )