Kaynağa Gözat

Merge pull request #18826 from Tishka17/fix/generic_prefetch_4.2

Prefetch interface data for REST API on netbox 4.2
bctiemann 11 ay önce
ebeveyn
işleme
b1e7d7c76b

+ 13 - 1
netbox/dcim/api/views.py

@@ -1,3 +1,4 @@
+from django.contrib.contenttypes.prefetch import GenericPrefetch
 from django.http import Http404, HttpResponse
 from django.shortcuts import get_object_or_404
 from drf_spectacular.types import OpenApiTypes
@@ -442,7 +443,18 @@ class PowerOutletViewSet(PathEndpointMixin, NetBoxModelViewSet):
 
 class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet):
     queryset = Interface.objects.prefetch_related(
-        '_path', 'cable__terminations',
+        GenericPrefetch(
+            "cable__terminations__termination",
+            [
+                Interface.objects.select_related("device", "cable"),
+            ],
+        ),
+        GenericPrefetch(
+            "_path__path_objects",
+            [
+                Interface.objects.select_related("device", "cable"),
+            ],
+        ),
         'l2vpn_terminations',  # Referenced by InterfaceSerializer.l2vpn_termination
         'ip_addresses',  # Referenced by Interface.count_ipaddresses()
         'fhrp_group_assignments',  # Referenced by Interface.count_fhrp_groups()

+ 11 - 45
netbox/dcim/models/cables.py

@@ -1,5 +1,4 @@
 import itertools
-from collections import defaultdict
 
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.core.exceptions import ValidationError
@@ -16,7 +15,7 @@ from dcim.utils import decompile_path_node, object_to_path_node
 from netbox.models import ChangeLoggedModel, PrimaryModel
 from utilities.conversion import to_meters
 from utilities.exceptions import AbortRequest
-from utilities.fields import ColorField
+from utilities.fields import ColorField, GenericArrayForeignKey
 from utilities.querysets import RestrictedQuerySet
 from wireless.models import WirelessLink
 from .device_components import FrontPort, RearPort, PathEndpoint
@@ -494,13 +493,16 @@ class CablePath(models.Model):
             return ObjectType.objects.get_for_id(ct_id)
 
     @property
-    def path_objects(self):
-        """
-        Cache and return the complete path as lists of objects, derived from their annotation within the path.
-        """
-        if not hasattr(self, '_path_objects'):
-            self._path_objects = self._get_path()
-        return self._path_objects
+    def _path_decompiled(self):
+        res = []
+        for step in self.path:
+            nodes = []
+            for node in step:
+                nodes.append(decompile_path_node(node))
+            res.append(nodes)
+        return res
+
+    path_objects = GenericArrayForeignKey("_path_decompiled")
 
     @property
     def origins(self):
@@ -757,42 +759,6 @@ class CablePath(models.Model):
             self.delete()
     retrace.alters_data = True
 
-    def _get_path(self):
-        """
-        Return the path as a list of prefetched objects.
-        """
-        # Compile a list of IDs to prefetch for each type of model in the path
-        to_prefetch = defaultdict(list)
-        for node in self._nodes:
-            ct_id, object_id = decompile_path_node(node)
-            to_prefetch[ct_id].append(object_id)
-
-        # Prefetch path objects using one query per model type. Prefetch related devices where appropriate.
-        prefetched = {}
-        for ct_id, object_ids in to_prefetch.items():
-            model_class = ObjectType.objects.get_for_id(ct_id).model_class()
-            queryset = model_class.objects.filter(pk__in=object_ids)
-            if hasattr(model_class, 'device'):
-                queryset = queryset.prefetch_related('device')
-            prefetched[ct_id] = {
-                obj.id: obj for obj in queryset
-            }
-
-        # Replicate the path using the prefetched objects.
-        path = []
-        for step in self.path:
-            nodes = []
-            for node in step:
-                ct_id, object_id = decompile_path_node(node)
-                try:
-                    nodes.append(prefetched[ct_id][object_id])
-                except KeyError:
-                    # Ignore stale (deleted) object IDs
-                    pass
-            path.append(nodes)
-
-        return path
-
     def get_cable_ids(self):
         """
         Return all Cable IDs within the path.

+ 5 - 2
netbox/dcim/models/device_components.py

@@ -184,8 +184,11 @@ class CabledObjectModel(models.Model):
     @cached_property
     def link_peers(self):
         if self.cable:
-            peers = self.cable.terminations.exclude(cable_end=self.cable_end).prefetch_related('termination')
-            return [peer.termination for peer in peers]
+            return [
+                peer.termination
+                for peer in self.cable.terminations.all()
+                if peer.cable_end != self.cable_end
+            ]
         return []
 
     @property

+ 132 - 0
netbox/utilities/fields.py

@@ -1,7 +1,11 @@
 from collections import defaultdict
 
 from django.contrib.contenttypes.fields import GenericForeignKey
+from django.contrib.contenttypes.models import ContentType
+from django.core.exceptions import ObjectDoesNotExist
 from django.db import models
+from django.db.models.fields.mixins import FieldCacheMixin
+from django.utils.functional import cached_property
 from django.utils.safestring import mark_safe
 from django.utils.translation import gettext_lazy as _
 
@@ -11,6 +15,7 @@ from .validators import ColorValidator
 __all__ = (
     'ColorField',
     'CounterCacheField',
+    'GenericArrayForeignKey',
     'NaturalOrderingField',
     'RestrictedGenericForeignKey',
 )
@@ -186,3 +191,130 @@ class CounterCacheField(models.BigIntegerField):
         kwargs["to_model"] = self.to_model_name
         kwargs["to_field"] = self.to_field_name
         return name, path, args, kwargs
+
+
+class GenericArrayForeignKey(FieldCacheMixin, models.Field):
+    """
+    Provide a generic many-to-many relation through an 2d array field
+    """
+
+    many_to_many = True
+    many_to_one = False
+    one_to_many = False
+    one_to_one = False
+
+    def __init__(self, field, for_concrete_model=True):
+        super().__init__(editable=False)
+        self.field = field
+        self.for_concrete_model = for_concrete_model
+        self.is_relation = True
+
+    def contribute_to_class(self, cls, name, **kwargs):
+        super().contribute_to_class(cls, name, private_only=True, **kwargs)
+        # GenericArrayForeignKey is its own descriptor.
+        setattr(cls, self.attname, self)
+
+    @cached_property
+    def cache_name(self):
+        return self.name
+
+    def get_cache_name(self):
+        return self.cache_name
+
+    def _get_ids(self, instance):
+        return getattr(instance, self.field)
+
+    def get_content_type_by_id(self, id=None, using=None):
+        return ContentType.objects.db_manager(using).get_for_id(id)
+
+    def get_content_type_of_obj(self, obj=None):
+        return ContentType.objects.db_manager(obj._state.db).get_for_model(
+            obj, for_concrete_model=self.for_concrete_model
+        )
+
+    def get_content_type_for_model(self, using=None, model=None):
+        return ContentType.objects.db_manager(using).get_for_model(
+            model, for_concrete_model=self.for_concrete_model
+        )
+
+    def get_prefetch_querysets(self, instances, querysets=None):
+        custom_queryset_dict = {}
+        if querysets is not None:
+            for queryset in querysets:
+                ct_id = self.get_content_type_for_model(
+                    model=queryset.query.model, using=queryset.db
+                ).pk
+                if ct_id in custom_queryset_dict:
+                    raise ValueError(
+                        "Only one queryset is allowed for each content type."
+                    )
+                custom_queryset_dict[ct_id] = queryset
+
+        # For efficiency, group the instances by content type and then do one
+        # query per model
+        fk_dict = defaultdict(set)  # type id, db -> model ids
+        for instance in instances:
+            for step in self._get_ids(instance):
+                for ct_id, fk_val in step:
+                    fk_dict[(ct_id, instance._state.db)].add(fk_val)
+
+        rel_objects = []
+        for (ct_id, db), fkeys in fk_dict.items():
+            if ct_id in custom_queryset_dict:
+                rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
+            else:
+                ct = self.get_content_type_by_id(id=ct_id, using=db)
+                rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
+
+        # reorganize objects to fix usage
+        items = {
+            (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
+            for rel_obj in rel_objects
+        }
+        lists = []
+        lists_keys = {}
+        for instance in instances:
+            data = []
+            lists.append(data)
+            lists_keys[instance] = id(data)
+            for step in self._get_ids(instance):
+                nodes = []
+                for ct, fk in step:
+                    if rel_obj := items.get((ct, fk, instance._state.db)):
+                        nodes.append(rel_obj)
+                data.append(nodes)
+
+        return (
+            lists,
+            lambda obj: id(obj),
+            lambda obj: lists_keys[obj],
+            True,
+            self.cache_name,
+            False,
+        )
+
+    def __get__(self, instance, cls=None):
+        if instance is None:
+            return self
+        rel_objects = self.get_cached_value(instance, default=...)
+        expected_ids = self._get_ids(instance)
+        # we do not check if cache actual
+        if rel_objects is not ...:
+            return rel_objects
+        # load value
+        if expected_ids is None:
+            self.set_cached_value(instance, rel_objects)
+            return rel_objects
+        data = []
+        for step in self._get_ids(instance):
+            rel_objects = []
+            for ct_id, pk_val in step:
+                ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
+                try:
+                    rel_obj = ct.get_object_for_this_type(pk=pk_val)
+                    rel_objects.append(rel_obj)
+                except ObjectDoesNotExist:
+                    pass
+            data.append(rel_objects)
+        self.set_cached_value(instance, data)
+        return data