Andrey Tikhonov 1 год назад
Родитель
Сommit
8dc1d68aee
3 измененных файлов с 176 добавлено и 2 удалено
  1. 19 1
      netbox/dcim/api/views.py
  2. 14 1
      netbox/dcim/models/cables.py
  3. 143 0
      netbox/utilities/generics/field.py

+ 19 - 1
netbox/dcim/api/views.py

@@ -1,3 +1,5 @@
+from django.contrib.contenttypes.prefetch import GenericPrefetch
+from django.db.models import Prefetch
 from django.http import Http404, HttpResponse
 from django.http import Http404, HttpResponse
 from django.shortcuts import get_object_or_404
 from django.shortcuts import get_object_or_404
 from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.types import OpenApiTypes
@@ -432,7 +434,23 @@ class PowerOutletViewSet(PathEndpointMixin, NetBoxModelViewSet):
 
 
 class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet):
 class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet):
     queryset = Interface.objects.prefetch_related(
     queryset = Interface.objects.prefetch_related(
-        '_path', 'cable__terminations',
+        # '_path',
+        # 'cable__terminations',
+        GenericPrefetch(
+            "cable__terminations__termination",
+            [
+                Interface.objects.prefetch_related("device"),
+            ],
+        ),
+        Prefetch(
+            "_path",
+            CablePath.objects.prefetch_related(
+                GenericPrefetch("path_objects", [
+                    Interface.objects.prefetch_related("device"),
+                    Cable.objects.prefetch_related("terminations"),
+                ]),
+            )
+        ),
         'l2vpn_terminations',  # Referenced by InterfaceSerializer.l2vpn_termination
         'l2vpn_terminations',  # Referenced by InterfaceSerializer.l2vpn_termination
         'ip_addresses',  # Referenced by Interface.count_ipaddresses()
         'ip_addresses',  # Referenced by Interface.count_ipaddresses()
         'fhrp_group_assignments',  # Referenced by Interface.count_fhrp_groups()
         'fhrp_group_assignments',  # Referenced by Interface.count_fhrp_groups()

+ 14 - 1
netbox/dcim/models/cables.py

@@ -17,6 +17,7 @@ from dcim.utils import decompile_path_node, object_to_path_node
 from netbox.models import ChangeLoggedModel, PrimaryModel
 from netbox.models import ChangeLoggedModel, PrimaryModel
 from utilities.conversion import to_meters
 from utilities.conversion import to_meters
 from utilities.fields import ColorField
 from utilities.fields import ColorField
+from utilities.generics.field import GenericArrayForeignKey
 from utilities.querysets import RestrictedQuerySet
 from utilities.querysets import RestrictedQuerySet
 from wireless.models import WirelessLink
 from wireless.models import WirelessLink
 from .device_components import FrontPort, RearPort, PathEndpoint
 from .device_components import FrontPort, RearPort, PathEndpoint
@@ -490,7 +491,7 @@ class CablePath(models.Model):
             return ObjectType.objects.get_for_id(ct_id)
             return ObjectType.objects.get_for_id(ct_id)
 
 
     @property
     @property
-    def path_objects(self):
+    def path_objects_old(self):
         """
         """
         Cache and return the complete path as lists of objects, derived from their annotation within the path.
         Cache and return the complete path as lists of objects, derived from their annotation within the path.
         """
         """
@@ -498,6 +499,18 @@ class CablePath(models.Model):
             self._path_objects = self._get_path()
             self._path_objects = self._get_path()
         return self._path_objects
         return self._path_objects
 
 
+    @property
+    def _path_decompiled(self):
+        res = []
+        for step in self.path:
+            nodes = []
+            for node in step:
+                nodes.append(decompile_path_node(node))
+            res.append(nodes)
+        return res
+
+    path_objects = GenericArrayForeignKey("_path_decompiled")
+
     @property
     @property
     def origins(self):
     def origins(self):
         """
         """

+ 143 - 0
netbox/utilities/generics/field.py

@@ -0,0 +1,143 @@
+from collections import defaultdict
+
+from django.contrib.contenttypes.models import ContentType
+from django.core.exceptions import ObjectDoesNotExist
+from django.db.models.fields import Field
+from django.db.models.fields.mixins import FieldCacheMixin
+from django.utils.functional import cached_property
+
+
+class GenericArrayForeignKey(FieldCacheMixin, Field):
+    """
+    Provide a generic many-to-many relation through an array field
+    """
+
+    many_to_many = True
+    many_to_one = False
+    one_to_many = False
+    one_to_one = False
+
+    def __init__(self, field, for_concrete_model=True):
+        super().__init__(editable=False)
+        self.field = field
+        self.for_concrete_model = for_concrete_model
+        self.is_relation = True
+
+    def contribute_to_class(self, cls, name, **kwargs):
+        super().contribute_to_class(cls, name, private_only=True, **kwargs)
+        # GenericForeignKey is its own descriptor.
+        setattr(cls, self.attname, self)
+
+    @cached_property
+    def cache_name(self):
+        return self.name
+
+    def get_cache_name(self):
+        return self.cache_name
+
+    def _get_ids(self, instance):
+        return getattr(instance, self.field)
+
+    def get_content_type_by_id(self, id=None, using=None):
+        return ContentType.objects.db_manager(using).get_for_id(id)
+
+    def get_content_type_of_obj(self, obj=None):
+        return ContentType.objects.db_manager(obj._state.db).get_for_model(
+            obj, for_concrete_model=self.for_concrete_model
+        )
+
+    def get_content_type_for_model(self, using=None, model=None):
+        return ContentType.objects.db_manager(using).get_for_model(
+            model, for_concrete_model=self.for_concrete_model
+        )
+
+    def get_prefetch_querysets(self, instances, querysets=None):
+        custom_queryset_dict = {}
+        if querysets is not None:
+            for queryset in querysets:
+                ct_id = self.get_content_type_for_model(
+                    model=queryset.query.model, using=queryset.db
+                ).pk
+                if ct_id in custom_queryset_dict:
+                    raise ValueError(
+                        "Only one queryset is allowed for each content type."
+                    )
+                custom_queryset_dict[ct_id] = queryset
+
+        # For efficiency, group the instances by content type and then do one
+        # query per model
+        fk_dict = defaultdict(set)  # type id, db -> model ids
+        for instance in instances:
+            for step in self._get_ids(instance):
+                for ct_id, fk_val in step:
+                    fk_dict[(ct_id, instance._state.db)].add(fk_val)
+
+        rel_objects = []
+        for (ct_id, db), fkeys in fk_dict.items():
+            if ct_id in custom_queryset_dict:
+                rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
+            else:
+                ct = self.get_content_type_by_id(id=ct_id, using=db)
+                rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
+
+        # reorganize objects to fix usage
+        items = {
+            (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
+            for rel_obj in rel_objects
+        }
+        lists = []
+        lists_keys = {}
+        for instance in instances:
+            data = []
+            lists.append(data)
+            lists_keys[instance] = id(data)
+            for step in self._get_ids(instance):
+                nodes = []
+                for ct, fk in step:
+                    if rel_obj := items.get((ct, fk, instance._state.db)):
+                        nodes.append(rel_obj)
+                data.append(nodes)
+
+        return (
+            lists,
+            lambda obj: id(obj),
+            lambda obj: lists_keys[obj],
+            True,
+            self.cache_name,
+            False,
+        )
+
+    def __get__(self, instance, cls=None):
+        if instance is None:
+            return self
+        rel_objects = self.get_cached_value(instance, default=None)
+        expected_ids = self._get_ids(instance)
+        # check cache actual
+        if rel_objects is not None:
+            actual = [
+                [
+                    (self.get_content_type_of_obj(obj=item).id, item.pk)
+                    for item in step
+                ]
+                for step in rel_objects
+            ]
+            if expected_ids == actual:
+                return rel_objects
+        # reload value
+        if expected_ids is None:
+            self.set_cached_value(instance, rel_objects)
+            return rel_objects
+        data = []
+        for step in self._get_ids(instance):
+            rel_objects = []
+            for ct_id, pk_val in step:
+                ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
+                try:
+                    rel_obj = ct.get_object_for_this_type(pk=pk_val)
+                    rel_objects.append(rel_obj)
+                except ObjectDoesNotExist:
+                    pass
+            data.append(rel_objects)
+        self.set_cached_value(instance, data)
+        return data
+