fields.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. from collections import defaultdict
  2. from django.contrib.contenttypes.fields import GenericForeignKey
  3. from django.contrib.contenttypes.models import ContentType
  4. from django.core.exceptions import ObjectDoesNotExist
  5. from django.db import models
  6. from django.db.models.fields.mixins import FieldCacheMixin
  7. from django.utils.functional import cached_property
  8. from django.utils.safestring import mark_safe
  9. from django.utils.translation import gettext_lazy as _
  10. from .forms.widgets import ColorSelect
  11. from .validators import ColorValidator
  # Public names exported by this module (also what ``import *`` picks up).
  __all__ = (
      "ColorField",
      "CounterCacheField",
      "GenericArrayForeignKey",
      "NaturalOrderingField",
      "RestrictedGenericForeignKey",
  )
  class ColorField(models.CharField):
      """
      CharField holding a hexadecimal RGB color code such as ``00ff00``.

      The column is always six characters wide. Values are checked by ColorValidator,
      and form fields always use the ColorSelect widget with a fixed help text.
      """

      default_validators = [ColorValidator]
      description = "A hexadecimal RGB color code"

      def __init__(self, *args, **kwargs):
          # Exactly six hex digits; any caller-supplied max_length is overridden.
          super().__init__(*args, **{**kwargs, 'max_length': 6})

      def formfield(self, **kwargs):
          # Widget and help text are always replaced, whatever the caller passed in.
          example_help = mark_safe(_('RGB color in hexadecimal. Example: ') + '<code>00ff00</code>')
          kwargs.update(widget=ColorSelect, help_text=example_help)
          return super().formfield(**kwargs)
  class NaturalOrderingField(models.CharField):
      """
      A field which stores a naturalized representation of its target field, to be used for
      ordering its parent model.

      The stored value is regenerated from the target field in ``pre_save()`` on every save,
      so any value assigned to this field directly is overwritten.

      :param target_field: Name of the field of the parent model to be naturalized
      :param naturalize_function: Required. Callable invoked as
          ``naturalize_function(value, max_length=<this field's max_length>)``; its return value
          is what gets stored. Because ``deconstruct()`` writes it into migrations, it should be
          an importable module-level function (not a lambda).
      """
      description = "Stores a representation of its target field suitable for natural ordering"

      def __init__(self, target_field, naturalize_function, *args, **kwargs):
          self.target_field = target_field
          self.naturalize_function = naturalize_function
          super().__init__(*args, **kwargs)

      def pre_save(self, model_instance, add):
          """
          Generate a naturalized value from the target field, store it on the instance and
          return it as the value to be saved.
          """
          source_value = getattr(model_instance, self.target_field)
          naturalized = self.naturalize_function(source_value, max_length=self.max_length)
          # Keep the in-memory instance in sync with what is written to the database.
          setattr(model_instance, self.attname, naturalized)
          return naturalized

      def deconstruct(self):
          # Reuse CharField's keyword arguments (max_length, blank, ...). Rebuild the positional
          # args and import path to match this field's own __init__ signature.
          # NOTE(review): the path is hard-coded, so subclasses also deconstruct as
          # NaturalOrderingField.
          kwargs = super().deconstruct()[3]
          kwargs['naturalize_function'] = self.naturalize_function
          return (
              self.name,
              'utilities.fields.NaturalOrderingField',
              [self.target_field],
              kwargs,
          )
  class RestrictedGenericForeignKey(GenericForeignKey):
      """
      GenericForeignKey whose prefetching can apply ``restrict()`` to the related querysets.

      ``get_prefetch_querysets()`` is replicated largely from Django's GenericForeignKey,
      with these changes:

      1. If ``querysets`` is exactly a ``dict`` (a hack used by RestrictedPrefetch to pass its
         parameters through), it is taken as the keyword arguments for ``restrict()``.
      2. When such parameters are present, ``restrict()`` is called on the queryset built for
         each related model.
      3. Instances whose content type no longer resolves to a model class are skipped.
      """

      def get_prefetch_querysets(self, instances, querysets=None):
          restrict_params = {}
          custom_querysets = ()
          # Compensate for the hack in RestrictedPrefetch: only an exact dict carries
          # restrict() parameters; anything else is Django's list of custom querysets.
          if type(querysets) is dict:
              restrict_params = querysets
          elif querysets is not None:
              custom_querysets = querysets

          custom_qs_by_ct = {}
          for custom_qs in custom_querysets:
              ct_pk = self.get_content_type(
                  model=custom_qs.query.model, using=custom_qs.db
              ).pk
              if ct_pk in custom_qs_by_ct:
                  raise ValueError(
                      "Only one queryset is allowed for each content type."
                  )
              custom_qs_by_ct[ct_pk] = custom_qs

          ct_attname = self.model._meta.get_field(self.ct_field).get_attname()

          # Group the wanted pks by content type so each related model is queried once.
          # The last instance seen for a content type supplies the database alias.
          pks_by_ct = defaultdict(set)
          alias_source = {}
          for instance in instances:
              ct_id = getattr(instance, ct_attname)
              if ct_id is None:
                  continue
              content_type = self.get_content_type(id=ct_id, using=instance._state.db)
              if not content_type.model_class():
                  # Stale content type: its model class is no longer available.
                  continue
              fk_val = getattr(instance, self.fk_field)
              if fk_val is None:
                  continue
              pks_by_ct[ct_id].add(fk_val)
              alias_source[ct_id] = instance

          related_objects = []
          for ct_id, pks in pks_by_ct.items():
              if ct_id in custom_qs_by_ct:
                  # Use the caller-provided queryset for this content type.
                  related_objects.extend(custom_qs_by_ct[ct_id].filter(pk__in=pks))
                  continue
              db_alias = alias_source[ct_id]._state.db
              model = self.get_content_type(id=ct_id, using=db_alias).model_class()
              # NOTE(review): this goes through the model's ``objects`` manager without
              # .using(db_alias), so database routing decides where it runs. Confirm that is
              # intended for multi-database setups.
              queryset = model.objects.filter(pk__in=pks)
              if restrict_params:
                  queryset = queryset.restrict(**restrict_params)
              related_objects.extend(queryset)

          def gfk_key(obj):
              # Instance-side join key: (prepared pk value, model class), or None.
              ct_id = getattr(obj, ct_attname)
              if ct_id is None:
                  return None
              model = self.get_content_type(id=ct_id, using=obj._state.db).model_class()
              if not model:
                  return None
              return model._meta.pk.get_prep_value(getattr(obj, self.fk_field)), model

          def related_key(rel_obj):
              # Related-side join key; same (pk, class) shape as gfk_key.
              return rel_obj.pk, rel_obj.__class__

          # Same tuple layout as Django's GenericForeignKey: objects, related key,
          # instance key, single, cache name, is_descriptor.
          return (
              related_objects,
              related_key,
              gfk_key,
              True,
              self.name,
              False,
          )
  class CounterCacheField(models.BigIntegerField):
      """
      Counter field to keep track of related model counts.

      :param to_model: The related model, given as a string in the format ``'app.model'``
      :param to_field: The related field, given by name as a string

      The field defaults to 0 unless a ``default`` is supplied, and it is always created
      with ``editable=False``.
      """

      def __init__(self, to_model, to_field, *args, **kwargs):
          field_class = type(self).__name__
          if not isinstance(to_model, str):
              raise TypeError(
                  _("%s(%r) is invalid. to_model parameter to CounterCacheField must be "
                    "a string in the format 'app.model'") % (field_class, to_model)
              )
          if not isinstance(to_field, str):
              raise TypeError(
                  _("%s(%r) is invalid. to_field parameter to CounterCacheField must be "
                    "a string in the format 'field'") % (field_class, to_field)
              )

          self.to_model_name = to_model
          self.to_field_name = to_field

          # Counters start at zero unless told otherwise, and are never user-editable.
          kwargs.setdefault('default', 0)
          kwargs['editable'] = False
          super().__init__(*args, **kwargs)

      def deconstruct(self):
          # Record the constructor-only arguments so migrations can rebuild the field.
          name, path, args, kwargs = super().deconstruct()
          kwargs.update(to_model=self.to_model_name, to_field=self.to_field_name)
          return name, path, args, kwargs
  class GenericArrayForeignKey(FieldCacheMixin, models.Field):
      """
      Provide a generic many-to-many relation through a two-dimensional array field.

      The model attribute named by ``field`` holds a sequence of "steps". Each step is an
      iterable of ``(content_type_id, object_pk)`` pairs. Reading this descriptor on an
      instance returns a list with one entry per step. Each entry is the list of related
      objects that could be resolved for that step; missing objects and stale content types
      are skipped. If the stored array is ``None``, the descriptor returns ``None``.
      """
      many_to_many = False
      many_to_one = False
      one_to_many = True
      one_to_one = False

      def __init__(self, field, for_concrete_model=True):
          """
          :param field: Name of the model attribute holding the array of steps
          :param for_concrete_model: Passed to the ContentType lookups done for related objects
          """
          super().__init__(editable=False)
          self.field = field
          self.for_concrete_model = for_concrete_model
          self.is_relation = True

      def contribute_to_class(self, cls, name, **kwargs):
          super().contribute_to_class(cls, name, private_only=True, **kwargs)
          # GenericArrayForeignKey is its own descriptor.
          setattr(cls, self.attname, self)

      @cached_property
      def cache_name(self):
          return self.name

      def get_cache_name(self):
          return self.cache_name

      def _get_ids(self, instance):
          """Return the raw array of steps stored on ``instance`` (may be None)."""
          return getattr(instance, self.field)

      def get_content_type_by_id(self, id=None, using=None):
          return ContentType.objects.db_manager(using).get_for_id(id)

      def get_content_type_of_obj(self, obj=None):
          return ContentType.objects.db_manager(obj._state.db).get_for_model(
              obj, for_concrete_model=self.for_concrete_model
          )

      def get_content_type_for_model(self, using=None, model=None):
          return ContentType.objects.db_manager(using).get_for_model(
              model, for_concrete_model=self.for_concrete_model
          )

      def get_prefetch_querysets(self, instances, querysets=None):
          """
          Prefetch the related objects for ``instances``.

          Related objects are loaded with one query per (content type, database) pair. Each
          instance then receives a list of lists mirroring its array, or ``None`` when its
          array is ``None`` (the same value ``__get__`` would produce).

          :raises ValueError: if two custom querysets target the same content type
          """
          custom_queryset_dict = {}
          if querysets is not None:
              for queryset in querysets:
                  ct_id = self.get_content_type_for_model(
                      model=queryset.query.model, using=queryset.db
                  ).pk
                  if ct_id in custom_queryset_dict:
                      raise ValueError(
                          "Only one queryset is allowed for each content type."
                      )
                  custom_queryset_dict[ct_id] = queryset

          # Read each instance's array once; both passes below reuse it.
          instance_steps = [(instance, self._get_ids(instance)) for instance in instances]

          # For efficiency, group the wanted pks by (content type id, database) and then do
          # one query per group. A None array contributes nothing.
          fk_dict = defaultdict(set)
          for instance, steps in instance_steps:
              for step in steps or ():
                  for ct_id, fk_val in step:
                      fk_dict[(ct_id, instance._state.db)].add(fk_val)

          rel_objects = []
          for (ct_id, db), fkeys in fk_dict.items():
              if ct_id in custom_queryset_dict:
                  rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
                  continue
              ct = self.get_content_type_by_id(id=ct_id, using=db)
              if ct.model_class() is None:
                  # Stale content type (model no longer available): nothing can be loaded.
                  continue
              rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

          # Index loaded objects by (content type id, pk, database) for the join below.
          items = {
              (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
              for rel_obj in rel_objects
          }

          lists = []
          lists_keys = {}
          for instance, steps in instance_steps:
              if steps is None:
                  data = None
              else:
                  db = instance._state.db
                  data = [
                      [items[(ct, fk, db)] for ct, fk in step if (ct, fk, db) in items]
                      for step in steps
                  ]
              lists.append(data)
              # Values are matched back to instances by id() of the per-instance value.
              # Each list is kept alive by ``lists``, so ids stay distinct. All None values
              # share id(None), which is harmless because they are identical.
              lists_keys[instance] = id(data)

          # Tuple layout: values, value key, instance key, single, cache name, is_descriptor.
          return (
              lists,
              id,
              lambda obj: lists_keys[obj],
              True,
              self.cache_name,
              False,
          )

      def __get__(self, instance, cls=None):
          if instance is None:
              return self
          # Any cached value (including one set by prefetching) is returned as-is; it is not
          # re-checked against the current array contents.
          cached = self.get_cached_value(instance, default=...)
          if cached is not ...:
              return cached
          steps = self._get_ids(instance)
          if steps is None:
              # No array stored: cache None so later accesses skip this work.
              self.set_cached_value(instance, None)
              return None
          db = instance._state.db
          data = []
          # NOTE: this issues one query per (content type, pk) pair; use
          # prefetch_related() to batch the loads for many instances.
          for step in steps:
              nodes = []
              for ct_id, pk_val in step:
                  ct = self.get_content_type_by_id(id=ct_id, using=db)
                  if ct.model_class() is None:
                      # Stale content type: its model is gone, so the object cannot be loaded.
                      continue
                  try:
                      rel_obj = ct.get_object_for_this_type(pk=pk_val)
                  except ObjectDoesNotExist:
                      # Dangling reference: the related row no longer exists.
                      continue
                  nodes.append(rel_obj)
              data.append(nodes)
          self.set_cached_value(instance, data)
          return data