Explorar o código

Introduce is_taggable utility function for identifying taggable models

Jeremy Stretch %!s(int64=6) %!d(string=hai) anos
pai
achega
ce4a5a38a3

+ 2 - 1
netbox/extras/middleware.py

@@ -9,6 +9,7 @@ from django.db.models.signals import pre_delete, post_save
 from django.utils import timezone
 from django_prometheus.models import model_deletes, model_inserts, model_updates
 
+from extras.utils import is_taggable
 from utilities.querysets import DummyQuerySet
 from .choices import ObjectChangeActionChoices
 from .models import ObjectChange
@@ -41,7 +42,7 @@ def handle_deleted_object(sender, instance, **kwargs):
     copy = deepcopy(instance)
 
     # Preserve tags
-    if hasattr(instance, 'tags'):
+    if is_taggable(instance):
         copy.tags = DummyQuerySet(instance.tags.all())
 
     # Queue the copy of the object for processing once the request completes

+ 15 - 0
netbox/extras/utils.py

@@ -0,0 +1,15 @@
+from taggit.managers import _TaggableManager
+from utilities.querysets import DummyQuerySet
+
+
+def is_taggable(obj):
+    """
+    Return True if the instance can have Tags assigned to it; False otherwise.
+    """
+    if hasattr(obj, 'tags'):
+        if issubclass(obj.tags.__class__, _TaggableManager):
+            return True
+        # TaggableManager has been replaced with a DummyQuerySet prior to object deletion
+        if isinstance(obj.tags, DummyQuerySet):
+            return True
+    return False

+ 3 - 2
netbox/utilities/utils.py

@@ -6,6 +6,7 @@ from django.core.serializers import serialize
 from django.db.models import Count, OuterRef, Subquery
 
 from dcim.choices import CableLengthUnitChoices
+from extras.utils import is_taggable
 
 
 def csv_format(data):
@@ -103,7 +104,7 @@ def serialize_object(obj, extra=None):
         }
 
     # Include any tags
-    if hasattr(obj, 'tags'):
+    if is_taggable(obj):
         data['tags'] = [tag.name for tag in obj.tags.all()]
 
     # Append any extra data
@@ -201,7 +202,7 @@ def prepare_cloned_fields(instance):
             params[field_name] = field_value
 
         # Copy tags
-        if hasattr(instance, 'tags'):
+        if is_taggable(instance):
             params['tags'] = ','.join([t.name for t in instance.tags.all()])
 
     # Concatenate parameters into a URL query string

+ 2 - 1
netbox/utilities/views.py

@@ -24,6 +24,7 @@ from django_tables2 import RequestConfig
 
 from extras.models import CustomField, CustomFieldValue, ExportTemplate
 from extras.querysets import CustomFieldQueryset
+from extras.utils import is_taggable
 from utilities.exceptions import AbortTransaction
 from utilities.forms import BootstrapMixin, CSVDataField
 from utilities.utils import csv_format, prepare_cloned_fields
@@ -144,7 +145,7 @@ class ObjectListView(View):
             table.columns.show('pk')
 
         # Construct queryset for tags list
-        if hasattr(model, 'tags') and type(model.tags).__name__ is not 'ManyToManyDescriptor':
+        if is_taggable(model):
             tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name')
         else:
             tags = None