|
|
@@ -3,10 +3,12 @@ from collections import defaultdict
|
|
|
from django.conf import settings
|
|
|
from django.contrib.contenttypes.models import ContentType
|
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
|
-from django.db.models import F, Window
|
|
|
+from django.db.models import F, Window, Q
|
|
|
from django.db.models.functions import window
|
|
|
from django.db.models.signals import post_delete, post_save
|
|
|
from django.utils.module_loading import import_string
|
|
|
+import netaddr
|
|
|
+from netaddr.core import AddrFormatError
|
|
|
|
|
|
from extras.models import CachedValue, CustomField
|
|
|
from netbox.registry import registry
|
|
|
@@ -52,11 +54,11 @@ class SearchBackend:
|
|
|
"""
|
|
|
raise NotImplementedError
|
|
|
|
|
|
- def caching_handler(self, sender, instance, **kwargs):
|
|
|
+ def caching_handler(self, sender, instance, created, **kwargs):
|
|
|
"""
|
|
|
Receiver for the post_save signal, responsible for caching object creation/changes.
|
|
|
"""
|
|
|
- self.cache(instance)
|
|
|
+ self.cache(instance, remove_existing=not created)
|
|
|
|
|
|
def removal_handler(self, sender, instance, **kwargs):
|
|
|
"""
|
|
|
@@ -78,7 +80,13 @@ class SearchBackend:
|
|
|
|
|
|
def clear(self, object_types=None):
|
|
|
"""
|
|
|
- Delete *all* cached data.
|
|
|
+ Delete *all* cached data (optionally filtered by object type).
|
|
|
+ """
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+ def count(self, object_types=None):
|
|
|
+ """
|
|
|
+ Return a count of all cache entries (optionally filtered by object type).
|
|
|
"""
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@@ -95,18 +103,24 @@ class CachedValueSearchBackend(SearchBackend):
|
|
|
|
|
|
def search(self, value, user=None, object_types=None, lookup=DEFAULT_LOOKUP_TYPE):
|
|
|
|
|
|
- # Define the search parameters
|
|
|
- params = {
|
|
|
- f'value__{lookup}': value
|
|
|
- }
|
|
|
+ query_filter = Q(**{f'value__{lookup}': value})
|
|
|
+
|
|
|
+ if object_types:
|
|
|
+ query_filter &= Q(object_type__in=object_types)
|
|
|
+
|
|
|
if lookup in (LookupTypes.STARTSWITH, LookupTypes.ENDSWITH):
|
|
|
# Partial string matches are valid only on string values
|
|
|
- params['type'] = FieldTypes.STRING
|
|
|
- if object_types:
|
|
|
- params['object_type__in'] = object_types
|
|
|
+ query_filter &= Q(type=FieldTypes.STRING)
|
|
|
+
|
|
|
+ if lookup == LookupTypes.PARTIAL:
|
|
|
+ try:
|
|
|
+ address = str(netaddr.IPNetwork(value.strip()).cidr)
|
|
|
+ query_filter |= Q(type=FieldTypes.CIDR) & Q(value__net_contains_or_equals=address)
|
|
|
+ except (AddrFormatError, ValueError):
|
|
|
+ pass
|
|
|
|
|
|
# Construct the base queryset to retrieve matching results
|
|
|
- queryset = CachedValue.objects.filter(**params).annotate(
|
|
|
+ queryset = CachedValue.objects.filter(query_filter).annotate(
|
|
|
# Annotate the rank of each result for its object according to its weight
|
|
|
row_number=Window(
|
|
|
expression=window.RowNumber(),
|
|
|
@@ -210,6 +224,12 @@ class CachedValueSearchBackend(SearchBackend):
|
|
|
# Call _raw_delete() on the queryset to avoid first loading instances into memory
|
|
|
return qs._raw_delete(using=qs.db)
|
|
|
|
|
|
+ def count(self, object_types=None):
|
|
|
+ qs = CachedValue.objects.all()
|
|
|
+ if object_types:
|
|
|
+ qs = qs.filter(object_type__in=object_types)
|
|
|
+ return qs.count()
|
|
|
+
|
|
|
@property
|
|
|
def size(self):
|
|
|
return CachedValue.objects.count()
|