field.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from collections import defaultdict
  2. from django.contrib.contenttypes.models import ContentType
  3. from django.core.exceptions import ObjectDoesNotExist
  4. from django.db.models.fields import Field
  5. from django.db.models.fields.mixins import FieldCacheMixin
  6. from django.utils.functional import cached_property
  7. class GenericArrayForeignKey(FieldCacheMixin, Field):
  8. """
  9. Provide a generic many-to-many relation through an array field
  10. """
  11. many_to_many = True
  12. many_to_one = False
  13. one_to_many = False
  14. one_to_one = False
  15. def __init__(self, field, for_concrete_model=True):
  16. super().__init__(editable=False)
  17. self.field = field
  18. self.for_concrete_model = for_concrete_model
  19. self.is_relation = True
  20. def contribute_to_class(self, cls, name, **kwargs):
  21. super().contribute_to_class(cls, name, private_only=True, **kwargs)
  22. # GenericForeignKey is its own descriptor.
  23. setattr(cls, self.attname, self)
  24. @cached_property
  25. def cache_name(self):
  26. return self.name
  27. def get_cache_name(self):
  28. return self.cache_name
  29. def _get_ids(self, instance):
  30. return getattr(instance, self.field)
  31. def get_content_type_by_id(self, id=None, using=None):
  32. return ContentType.objects.db_manager(using).get_for_id(id)
  33. def get_content_type_of_obj(self, obj=None):
  34. return ContentType.objects.db_manager(obj._state.db).get_for_model(
  35. obj, for_concrete_model=self.for_concrete_model
  36. )
  37. def get_content_type_for_model(self, using=None, model=None):
  38. return ContentType.objects.db_manager(using).get_for_model(
  39. model, for_concrete_model=self.for_concrete_model
  40. )
  41. def get_prefetch_querysets(self, instances, querysets=None):
  42. custom_queryset_dict = {}
  43. if querysets is not None:
  44. for queryset in querysets:
  45. ct_id = self.get_content_type_for_model(
  46. model=queryset.query.model, using=queryset.db
  47. ).pk
  48. if ct_id in custom_queryset_dict:
  49. raise ValueError(
  50. "Only one queryset is allowed for each content type."
  51. )
  52. custom_queryset_dict[ct_id] = queryset
  53. # For efficiency, group the instances by content type and then do one
  54. # query per model
  55. fk_dict = defaultdict(set) # type id, db -> model ids
  56. for instance in instances:
  57. for step in self._get_ids(instance):
  58. for ct_id, fk_val in step:
  59. fk_dict[(ct_id, instance._state.db)].add(fk_val)
  60. rel_objects = []
  61. for (ct_id, db), fkeys in fk_dict.items():
  62. if ct_id in custom_queryset_dict:
  63. rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
  64. else:
  65. ct = self.get_content_type_by_id(id=ct_id, using=db)
  66. rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
  67. # reorganize objects to fix usage
  68. items = {
  69. (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
  70. for rel_obj in rel_objects
  71. }
  72. lists = []
  73. lists_keys = {}
  74. for instance in instances:
  75. data = []
  76. lists.append(data)
  77. lists_keys[instance] = id(data)
  78. for step in self._get_ids(instance):
  79. nodes = []
  80. for ct, fk in step:
  81. if rel_obj := items.get((ct, fk, instance._state.db)):
  82. nodes.append(rel_obj)
  83. data.append(nodes)
  84. return (
  85. lists,
  86. lambda obj: id(obj),
  87. lambda obj: lists_keys[obj],
  88. True,
  89. self.cache_name,
  90. False,
  91. )
  92. def __get__(self, instance, cls=None):
  93. if instance is None:
  94. return self
  95. rel_objects = self.get_cached_value(instance, default=...)
  96. expected_ids = self._get_ids(instance)
  97. # we do not check if cache actual
  98. if rel_objects is not ...:
  99. return rel_objects
  100. # load value
  101. if expected_ids is None:
  102. self.set_cached_value(instance, rel_objects)
  103. return rel_objects
  104. data = []
  105. for step in self._get_ids(instance):
  106. rel_objects = []
  107. for ct_id, pk_val in step:
  108. ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
  109. try:
  110. rel_obj = ct.get_object_for_this_type(pk=pk_val)
  111. rel_objects.append(rel_obj)
  112. except ObjectDoesNotExist:
  113. pass
  114. data.append(rel_objects)
  115. self.set_cached_value(instance, data)
  116. return data