|
@@ -70,14 +70,24 @@ class RestrictedGenericForeignKey(GenericForeignKey):
|
|
|
# 1. Capture restrict_params from RestrictedPrefetch (hack)
|
|
# 1. Capture restrict_params from RestrictedPrefetch (hack)
|
|
|
# 2. If restrict_params is set, call restrict() on the queryset for
|
|
# 2. If restrict_params is set, call restrict() on the queryset for
|
|
|
# the related model
|
|
# the related model
|
|
|
- def get_prefetch_queryset(self, instances, queryset=None):
|
|
|
|
|
|
|
+ def get_prefetch_querysets(self, instances, querysets=None):
|
|
|
restrict_params = {}
|
|
restrict_params = {}
|
|
|
|
|
+ custom_queryset_dict = {}
|
|
|
|
|
|
|
|
# Compensate for the hack in RestrictedPrefetch
|
|
# Compensate for the hack in RestrictedPrefetch
|
|
|
- if type(queryset) is dict:
|
|
|
|
|
- restrict_params = queryset
|
|
|
|
|
- elif queryset is not None:
|
|
|
|
|
- raise ValueError(_("Custom queryset can't be used for this lookup."))
|
|
|
|
|
|
|
+ if type(querysets) is dict:
|
|
|
|
|
+ restrict_params = querysets
|
|
|
|
|
+
|
|
|
|
|
+ elif querysets is not None:
|
|
|
|
|
+ for queryset in querysets:
|
|
|
|
|
+ ct_id = self.get_content_type(
|
|
|
|
|
+ model=queryset.query.model, using=queryset.db
|
|
|
|
|
+ ).pk
|
|
|
|
|
+ if ct_id in custom_queryset_dict:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ "Only one queryset is allowed for each content type."
|
|
|
|
|
+ )
|
|
|
|
|
+ custom_queryset_dict[ct_id] = queryset
|
|
|
|
|
|
|
|
# For efficiency, group the instances by content type and then do one
|
|
# For efficiency, group the instances by content type and then do one
|
|
|
# query per model
|
|
# query per model
|
|
@@ -100,15 +110,16 @@ class RestrictedGenericForeignKey(GenericForeignKey):
|
|
|
|
|
|
|
|
ret_val = []
|
|
ret_val = []
|
|
|
for ct_id, fkeys in fk_dict.items():
|
|
for ct_id, fkeys in fk_dict.items():
|
|
|
- instance = instance_dict[ct_id]
|
|
|
|
|
- ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
|
|
|
|
- if restrict_params:
|
|
|
|
|
- # Override the default behavior to call restrict() on each model's queryset
|
|
|
|
|
- qs = ct.model_class().objects.filter(pk__in=fkeys).restrict(**restrict_params)
|
|
|
|
|
- ret_val.extend(qs)
|
|
|
|
|
|
|
+ if ct_id in custom_queryset_dict:
|
|
|
|
|
+ # Return values from the custom queryset, if provided.
|
|
|
|
|
+ ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
|
|
|
else:
|
|
else:
|
|
|
- # Default behavior
|
|
|
|
|
- ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
|
|
|
|
|
|
|
+ instance = instance_dict[ct_id]
|
|
|
|
|
+ ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
|
|
|
|
+ qs = ct.model_class().objects.filter(pk__in=fkeys)
|
|
|
|
|
+ if restrict_params:
|
|
|
|
|
+ qs = qs.restrict(**restrict_params)
|
|
|
|
|
+ ret_val.extend(qs)
|
|
|
|
|
|
|
|
# For doing the join in Python, we have to match both the FK val and the
|
|
# For doing the join in Python, we have to match both the FK val and the
|
|
|
# content type, so we use a callable that returns a (fk, class) pair.
|
|
# content type, so we use a callable that returns a (fk, class) pair.
|