fields.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. from collections import defaultdict
  2. from django.contrib.contenttypes.fields import GenericForeignKey
  3. from django.contrib.contenttypes.models import ContentType
  4. from django.core.exceptions import ObjectDoesNotExist
  5. from django.db import models
  6. from django.db.models.fields.mixins import FieldCacheMixin
  7. from django.utils.functional import cached_property
  8. from django.utils.translation import gettext_lazy as _
  9. from .forms.widgets import ColorSelect
  10. from .validators import ColorValidator
  11. __all__ = (
  12. 'ColorField',
  13. 'CounterCacheField',
  14. 'GenericArrayForeignKey',
  15. 'NaturalOrderingField',
  16. 'RestrictedGenericForeignKey',
  17. )
  18. class ColorField(models.CharField):
  19. default_validators = [ColorValidator]
  20. description = "A hexadecimal RGB color code"
  21. def __init__(self, *args, **kwargs):
  22. kwargs['max_length'] = 6
  23. super().__init__(*args, **kwargs)
  24. def formfield(self, **kwargs):
  25. kwargs['widget'] = ColorSelect
  26. return super().formfield(**kwargs)
  27. class NaturalOrderingField(models.CharField):
  28. """
  29. A field which stores a naturalized representation of its target field, to be used for ordering its parent model.
  30. :param target_field: Name of the field of the parent model to be naturalized
  31. :param naturalize_function: The function used to generate a naturalized value (optional)
  32. """
  33. description = "Stores a representation of its target field suitable for natural ordering"
  34. def __init__(self, target_field, naturalize_function, *args, **kwargs):
  35. self.target_field = target_field
  36. self.naturalize_function = naturalize_function
  37. super().__init__(*args, **kwargs)
  38. def pre_save(self, model_instance, add):
  39. """
  40. Generate a naturalized value from the target field
  41. """
  42. original_value = getattr(model_instance, self.target_field)
  43. naturalized_value = self.naturalize_function(original_value, max_length=self.max_length)
  44. setattr(model_instance, self.attname, naturalized_value)
  45. return naturalized_value
  46. def deconstruct(self):
  47. kwargs = super().deconstruct()[3] # Pass kwargs from CharField
  48. kwargs['naturalize_function'] = self.naturalize_function
  49. return (
  50. self.name,
  51. 'utilities.fields.NaturalOrderingField',
  52. [self.target_field],
  53. kwargs,
  54. )
  55. class RestrictedGenericForeignKey(GenericForeignKey):
  56. # Replicated largely from GenericForeignKey. Changes include:
  57. # 1. Capture restrict_params from RestrictedPrefetch (hack)
  58. # 2. If restrict_params is set, call restrict() on the queryset for
  59. # the related model
  60. def get_prefetch_querysets(self, instances, querysets=None):
  61. restrict_params = {}
  62. custom_queryset_dict = {}
  63. # Compensate for the hack in RestrictedPrefetch
  64. if type(querysets) is dict:
  65. restrict_params = querysets
  66. elif querysets is not None:
  67. for queryset in querysets:
  68. ct_id = self.get_content_type(
  69. model=queryset.query.model, using=queryset.db
  70. ).pk
  71. if ct_id in custom_queryset_dict:
  72. raise ValueError(
  73. "Only one queryset is allowed for each content type."
  74. )
  75. custom_queryset_dict[ct_id] = queryset
  76. # For efficiency, group the instances by content type and then do one
  77. # query per model
  78. fk_dict = defaultdict(set)
  79. # We need one instance for each group in order to get the right db:
  80. instance_dict = {}
  81. ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
  82. for instance in instances:
  83. # We avoid looking for values if either ct_id or fkey value is None
  84. ct_id = getattr(instance, ct_attname)
  85. if ct_id is not None:
  86. # Check if the content type actually exists
  87. if not self.get_content_type(id=ct_id, using=instance._state.db).model_class():
  88. continue
  89. fk_val = getattr(instance, self.fk_field)
  90. if fk_val is not None:
  91. fk_dict[ct_id].add(fk_val)
  92. instance_dict[ct_id] = instance
  93. ret_val = []
  94. for ct_id, fkeys in fk_dict.items():
  95. if ct_id in custom_queryset_dict:
  96. # Return values from the custom queryset, if provided.
  97. ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
  98. else:
  99. instance = instance_dict[ct_id]
  100. ct = self.get_content_type(id=ct_id, using=instance._state.db)
  101. qs = ct.model_class().objects.filter(pk__in=fkeys)
  102. if restrict_params:
  103. qs = qs.restrict(**restrict_params)
  104. ret_val.extend(qs)
  105. # For doing the join in Python, we have to match both the FK val and the
  106. # content type, so we use a callable that returns a (fk, class) pair.
  107. def gfk_key(obj):
  108. ct_id = getattr(obj, ct_attname)
  109. if ct_id is None:
  110. return None
  111. if model := self.get_content_type(
  112. id=ct_id, using=obj._state.db
  113. ).model_class():
  114. return (
  115. model._meta.pk.get_prep_value(getattr(obj, self.fk_field)),
  116. model,
  117. )
  118. return None
  119. return (
  120. ret_val,
  121. lambda obj: (obj.pk, obj.__class__),
  122. gfk_key,
  123. True,
  124. self.name,
  125. False,
  126. )
  127. class CounterCacheField(models.BigIntegerField):
  128. """
  129. Counter field to keep track of related model counts.
  130. """
  131. def __init__(self, to_model, to_field, *args, **kwargs):
  132. if not isinstance(to_model, str):
  133. raise TypeError(
  134. _("%s(%r) is invalid. to_model parameter to CounterCacheField must be "
  135. "a string in the format 'app.model'")
  136. % (
  137. self.__class__.__name__,
  138. to_model,
  139. )
  140. )
  141. if not isinstance(to_field, str):
  142. raise TypeError(
  143. _("%s(%r) is invalid. to_field parameter to CounterCacheField must be "
  144. "a string in the format 'field'")
  145. % (
  146. self.__class__.__name__,
  147. to_field,
  148. )
  149. )
  150. self.to_model_name = to_model
  151. self.to_field_name = to_field
  152. kwargs['default'] = kwargs.get('default', 0)
  153. kwargs['editable'] = False
  154. super().__init__(*args, **kwargs)
  155. def deconstruct(self):
  156. name, path, args, kwargs = super().deconstruct()
  157. kwargs["to_model"] = self.to_model_name
  158. kwargs["to_field"] = self.to_field_name
  159. return name, path, args, kwargs
  160. class GenericArrayForeignKey(FieldCacheMixin, models.Field):
  161. """
  162. Provide a generic many-to-many relation through an 2d array field
  163. """
  164. many_to_many = False
  165. many_to_one = False
  166. one_to_many = True
  167. one_to_one = False
  168. def __init__(self, field, for_concrete_model=True):
  169. super().__init__(editable=False)
  170. self.field = field
  171. self.for_concrete_model = for_concrete_model
  172. self.is_relation = True
  173. def contribute_to_class(self, cls, name, **kwargs):
  174. super().contribute_to_class(cls, name, private_only=True, **kwargs)
  175. # GenericArrayForeignKey is its own descriptor.
  176. setattr(cls, self.attname, self)
  177. @cached_property
  178. def cache_name(self):
  179. return self.name
  180. def get_cache_name(self):
  181. return self.cache_name
  182. def _get_ids(self, instance):
  183. return getattr(instance, self.field)
  184. def get_content_type_by_id(self, id=None, using=None):
  185. return ContentType.objects.db_manager(using).get_for_id(id)
  186. def get_content_type_of_obj(self, obj=None):
  187. return ContentType.objects.db_manager(obj._state.db).get_for_model(
  188. obj, for_concrete_model=self.for_concrete_model
  189. )
  190. def get_content_type_for_model(self, using=None, model=None):
  191. return ContentType.objects.db_manager(using).get_for_model(
  192. model, for_concrete_model=self.for_concrete_model
  193. )
  194. def get_prefetch_querysets(self, instances, querysets=None):
  195. custom_queryset_dict = {}
  196. if querysets is not None:
  197. for queryset in querysets:
  198. ct_id = self.get_content_type_for_model(
  199. model=queryset.query.model, using=queryset.db
  200. ).pk
  201. if ct_id in custom_queryset_dict:
  202. raise ValueError(
  203. "Only one queryset is allowed for each content type."
  204. )
  205. custom_queryset_dict[ct_id] = queryset
  206. # For efficiency, group the instances by content type and then do one
  207. # query per model
  208. fk_dict = defaultdict(set) # type id, db -> model ids
  209. for instance in instances:
  210. for step in self._get_ids(instance):
  211. for ct_id, fk_val in step:
  212. fk_dict[(ct_id, instance._state.db)].add(fk_val)
  213. rel_objects = []
  214. for (ct_id, db), fkeys in fk_dict.items():
  215. if ct_id in custom_queryset_dict:
  216. rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
  217. else:
  218. ct = self.get_content_type_by_id(id=ct_id, using=db)
  219. rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
  220. # reorganize objects to fix usage
  221. items = {
  222. (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
  223. for rel_obj in rel_objects
  224. }
  225. lists = []
  226. lists_keys = {}
  227. for instance in instances:
  228. data = []
  229. lists.append(data)
  230. lists_keys[instance] = id(data)
  231. for step in self._get_ids(instance):
  232. nodes = []
  233. for ct, fk in step:
  234. if rel_obj := items.get((ct, fk, instance._state.db)):
  235. nodes.append(rel_obj)
  236. data.append(nodes)
  237. return (
  238. lists,
  239. lambda obj: id(obj),
  240. lambda obj: lists_keys[obj],
  241. True,
  242. self.cache_name,
  243. False,
  244. )
  245. def __get__(self, instance, cls=None):
  246. if instance is None:
  247. return self
  248. rel_objects = self.get_cached_value(instance, default=...)
  249. expected_ids = self._get_ids(instance)
  250. # we do not check if cache actual
  251. if rel_objects is not ...:
  252. return rel_objects
  253. # load value
  254. if expected_ids is None:
  255. self.set_cached_value(instance, rel_objects)
  256. return rel_objects
  257. data = []
  258. for step in self._get_ids(instance):
  259. rel_objects = []
  260. for ct_id, pk_val in step:
  261. ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
  262. try:
  263. rel_obj = ct.get_object_for_this_type(pk=pk_val)
  264. rel_objects.append(rel_obj)
  265. except ObjectDoesNotExist:
  266. pass
  267. data.append(rel_objects)
  268. self.set_cached_value(instance, data)
  269. return data