Просмотр исходного кода

Introduce PathContains lookup to allow filtering against objects in path directly

Jeremy Stretch 5 лет назад
Родитель
Сommit
56ee425227

+ 4 - 1
netbox/dcim/fields.py

@@ -1,11 +1,11 @@
 from django.contrib.postgres.fields import ArrayField
-from django.contrib.postgres.validators import ArrayMaxLengthValidator
 from django.core.exceptions import ValidationError
 from django.core.validators import MinValueValidator, MaxValueValidator
 from django.db import models
 from netaddr import AddrFormatError, EUI, mac_unix_expanded
 
 from ipam.constants import BGP_ASN_MAX, BGP_ASN_MIN
+from .lookups import PathContains
 
 
 class ASNField(models.BigIntegerField):
@@ -61,3 +61,6 @@ class PathField(ArrayField):
     def __init__(self, **kwargs):
         kwargs['base_field'] = models.CharField(max_length=40)
         super().__init__(**kwargs)
+
+
+PathField.register_lookup(PathContains)

+ 10 - 0
netbox/dcim/lookups.py

@@ -0,0 +1,10 @@
+from django.contrib.postgres.fields.array import ArrayContains
+
+from dcim.utils import object_to_path_node
+
+
+class PathContains(ArrayContains):
+
+    def get_prep_lookup(self):
+        self.rhs = [object_to_path_node(self.rhs)]
+        return super().get_prep_lookup()

+ 4 - 5
netbox/dcim/signals.py

@@ -7,7 +7,7 @@ from django.dispatch import receiver
 
 from .choices import CableStatusChoices
 from .models import Cable, CablePath, Device, PathEndpoint, VirtualChassis
-from .utils import object_to_path_node, trace_path
+from .utils import trace_path
 
 
 def create_cablepath(node):
@@ -24,8 +24,7 @@ def rebuild_paths(obj):
     """
     Rebuild all CablePaths which traverse the specified node
     """
-    node = object_to_path_node(obj)
-    cable_paths = CablePath.objects.filter(path__contains=[node])
+    cable_paths = CablePath.objects.filter(path__contains=obj)
 
     with transaction.atomic():
         for cp in cable_paths:
@@ -86,7 +85,7 @@ def update_connected_endpoints(instance, created, **kwargs):
         # may change in the future.) However, we do need to capture status changes and update
         # any CablePaths accordingly.
         if instance.status != CableStatusChoices.STATUS_CONNECTED:
-            CablePath.objects.filter(path__contains=[object_to_path_node(instance)]).update(is_active=False)
+            CablePath.objects.filter(path__contains=instance).update(is_active=False)
         else:
             rebuild_paths(instance)
 
@@ -109,7 +108,7 @@ def nullify_connected_endpoints(instance, **kwargs):
         instance.termination_b.save()
 
     # Delete and retrace any dependent cable paths
-    for cablepath in CablePath.objects.filter(path__contains=[object_to_path_node(instance)]):
+    for cablepath in CablePath.objects.filter(path__contains=instance):
         path, destination, is_active = trace_path(cablepath.origin)
         if path:
             CablePath.objects.filter(pk=cablepath.pk).update(

+ 2 - 2
netbox/dcim/tests/test_cablepaths.py

@@ -4,7 +4,7 @@ from django.test import TestCase
 from circuits.models import *
 from dcim.choices import CableStatusChoices
 from dcim.models import *
-from dcim.utils import objects_to_path
+from dcim.utils import object_to_path_node
 
 
 class CablePathTestCase(TestCase):
@@ -146,7 +146,7 @@ class CablePathTestCase(TestCase):
             kwargs['destination_type__isnull'] = True
             kwargs['destination_id__isnull'] = True
         if path is not None:
-            kwargs['path'] = objects_to_path(*path)
+            kwargs['path'] = [object_to_path_node(obj) for obj in path]
         if is_active is not None:
             kwargs['is_active'] = is_active
         if msg is None:

+ 0 - 4
netbox/dcim/utils.py

@@ -8,10 +8,6 @@ def object_to_path_node(obj):
     return f'{obj._meta.model_name}:{obj.pk}'
 
 
-def objects_to_path(*obj_list):
-    return [object_to_path_node(obj) for obj in obj_list]
-
-
 def path_node_to_object(repr):
     model_name, object_id = repr.split(':')
     model_class = ContentType.objects.get(model=model_name).model_class()

+ 1 - 4
netbox/dcim/views.py

@@ -38,7 +38,6 @@ from .models import (
     PowerPort, PowerPortTemplate, Rack, RackGroup, RackReservation, RackRole, RearPort, RearPortTemplate, Region, Site,
     VirtualChassis,
 )
-from .utils import object_to_path_node
 
 
 class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
@@ -1974,9 +1973,7 @@ class PathTraceView(ObjectView):
             path = obj._path
         # Otherwise, find all CablePaths which traverse the specified object
         else:
-            related_paths = CablePath.objects.filter(
-                path__contains=[object_to_path_node(obj)]
-            ).prefetch_related('origin')
+            related_paths = CablePath.objects.filter(path__contains=obj).prefetch_related('origin')
             # Check for specification of a particular path (when tracing pass-through ports)
             try:
                 path_id = int(request.GET.get('cablepath_id'))