瀏覽代碼

Introduce PathContains lookup to allow filtering against objects in path directly

Jeremy Stretch 5 年之前
父節點
當前提交
56ee425227
共有 6 個文件被更改,包括 21 次插入16 次删除
  1. 4 1
      netbox/dcim/fields.py
  2. 10 0
      netbox/dcim/lookups.py
  3. 4 5
      netbox/dcim/signals.py
  4. 2 2
      netbox/dcim/tests/test_cablepaths.py
  5. 0 4
      netbox/dcim/utils.py
  6. 1 4
      netbox/dcim/views.py

+ 4 - 1
netbox/dcim/fields.py

@@ -1,11 +1,11 @@
 from django.contrib.postgres.fields import ArrayField
 from django.contrib.postgres.fields import ArrayField
-from django.contrib.postgres.validators import ArrayMaxLengthValidator
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
 from django.core.validators import MinValueValidator, MaxValueValidator
 from django.core.validators import MinValueValidator, MaxValueValidator
 from django.db import models
 from django.db import models
 from netaddr import AddrFormatError, EUI, mac_unix_expanded
 from netaddr import AddrFormatError, EUI, mac_unix_expanded
 
 
 from ipam.constants import BGP_ASN_MAX, BGP_ASN_MIN
 from ipam.constants import BGP_ASN_MAX, BGP_ASN_MIN
+from .lookups import PathContains
 
 
 
 
 class ASNField(models.BigIntegerField):
 class ASNField(models.BigIntegerField):
@@ -61,3 +61,6 @@ class PathField(ArrayField):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         kwargs['base_field'] = models.CharField(max_length=40)
         kwargs['base_field'] = models.CharField(max_length=40)
         super().__init__(**kwargs)
         super().__init__(**kwargs)
+
+
+PathField.register_lookup(PathContains)

+ 10 - 0
netbox/dcim/lookups.py

@@ -0,0 +1,10 @@
+from django.contrib.postgres.fields.array import ArrayContains
+
+from dcim.utils import object_to_path_node
+
+
+class PathContains(ArrayContains):
+
+    def get_prep_lookup(self):
+        self.rhs = [object_to_path_node(self.rhs)]
+        return super().get_prep_lookup()

+ 4 - 5
netbox/dcim/signals.py

@@ -7,7 +7,7 @@ from django.dispatch import receiver
 
 
 from .choices import CableStatusChoices
 from .choices import CableStatusChoices
 from .models import Cable, CablePath, Device, PathEndpoint, VirtualChassis
 from .models import Cable, CablePath, Device, PathEndpoint, VirtualChassis
-from .utils import object_to_path_node, trace_path
+from .utils import trace_path
 
 
 
 
 def create_cablepath(node):
 def create_cablepath(node):
@@ -24,8 +24,7 @@ def rebuild_paths(obj):
     """
     """
     Rebuild all CablePaths which traverse the specified node
     Rebuild all CablePaths which traverse the specified node
     """
     """
-    node = object_to_path_node(obj)
-    cable_paths = CablePath.objects.filter(path__contains=[node])
+    cable_paths = CablePath.objects.filter(path__contains=obj)
 
 
     with transaction.atomic():
     with transaction.atomic():
         for cp in cable_paths:
         for cp in cable_paths:
@@ -86,7 +85,7 @@ def update_connected_endpoints(instance, created, **kwargs):
         # may change in the future.) However, we do need to capture status changes and update
         # may change in the future.) However, we do need to capture status changes and update
         # any CablePaths accordingly.
         # any CablePaths accordingly.
         if instance.status != CableStatusChoices.STATUS_CONNECTED:
         if instance.status != CableStatusChoices.STATUS_CONNECTED:
-            CablePath.objects.filter(path__contains=[object_to_path_node(instance)]).update(is_active=False)
+            CablePath.objects.filter(path__contains=instance).update(is_active=False)
         else:
         else:
             rebuild_paths(instance)
             rebuild_paths(instance)
 
 
@@ -109,7 +108,7 @@ def nullify_connected_endpoints(instance, **kwargs):
         instance.termination_b.save()
         instance.termination_b.save()
 
 
     # Delete and retrace any dependent cable paths
     # Delete and retrace any dependent cable paths
-    for cablepath in CablePath.objects.filter(path__contains=[object_to_path_node(instance)]):
+    for cablepath in CablePath.objects.filter(path__contains=instance):
         path, destination, is_active = trace_path(cablepath.origin)
         path, destination, is_active = trace_path(cablepath.origin)
         if path:
         if path:
             CablePath.objects.filter(pk=cablepath.pk).update(
             CablePath.objects.filter(pk=cablepath.pk).update(

+ 2 - 2
netbox/dcim/tests/test_cablepaths.py

@@ -4,7 +4,7 @@ from django.test import TestCase
 from circuits.models import *
 from circuits.models import *
 from dcim.choices import CableStatusChoices
 from dcim.choices import CableStatusChoices
 from dcim.models import *
 from dcim.models import *
-from dcim.utils import objects_to_path
+from dcim.utils import object_to_path_node
 
 
 
 
 class CablePathTestCase(TestCase):
 class CablePathTestCase(TestCase):
@@ -146,7 +146,7 @@ class CablePathTestCase(TestCase):
             kwargs['destination_type__isnull'] = True
             kwargs['destination_type__isnull'] = True
             kwargs['destination_id__isnull'] = True
             kwargs['destination_id__isnull'] = True
         if path is not None:
         if path is not None:
-            kwargs['path'] = objects_to_path(*path)
+            kwargs['path'] = [object_to_path_node(obj) for obj in path]
         if is_active is not None:
         if is_active is not None:
             kwargs['is_active'] = is_active
             kwargs['is_active'] = is_active
         if msg is None:
         if msg is None:

+ 0 - 4
netbox/dcim/utils.py

@@ -8,10 +8,6 @@ def object_to_path_node(obj):
     return f'{obj._meta.model_name}:{obj.pk}'
     return f'{obj._meta.model_name}:{obj.pk}'
 
 
 
 
-def objects_to_path(*obj_list):
-    return [object_to_path_node(obj) for obj in obj_list]
-
-
 def path_node_to_object(repr):
 def path_node_to_object(repr):
     model_name, object_id = repr.split(':')
     model_name, object_id = repr.split(':')
     model_class = ContentType.objects.get(model=model_name).model_class()
     model_class = ContentType.objects.get(model=model_name).model_class()

+ 1 - 4
netbox/dcim/views.py

@@ -38,7 +38,6 @@ from .models import (
     PowerPort, PowerPortTemplate, Rack, RackGroup, RackReservation, RackRole, RearPort, RearPortTemplate, Region, Site,
     PowerPort, PowerPortTemplate, Rack, RackGroup, RackReservation, RackRole, RearPort, RearPortTemplate, Region, Site,
     VirtualChassis,
     VirtualChassis,
 )
 )
-from .utils import object_to_path_node
 
 
 
 
 class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
@@ -1974,9 +1973,7 @@ class PathTraceView(ObjectView):
             path = obj._path
             path = obj._path
         # Otherwise, find all CablePaths which traverse the specified object
         # Otherwise, find all CablePaths which traverse the specified object
         else:
         else:
-            related_paths = CablePath.objects.filter(
-                path__contains=[object_to_path_node(obj)]
-            ).prefetch_related('origin')
+            related_paths = CablePath.objects.filter(path__contains=obj).prefetch_related('origin')
             # Check for specification of a particular path (when tracing pass-through ports)
             # Check for specification of a particular path (when tracing pass-through ports)
             try:
             try:
                 path_id = int(request.GET.get('cablepath_id'))
                 path_id = int(request.GET.get('cablepath_id'))