Просмотр исходного кода

Initial work on custom model validation

jeremystretch 4 лет назад
Родитель
Сommit
3bfa1cbf41

+ 14 - 0
netbox/extras/signals.py

@@ -6,10 +6,12 @@ from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.db import DEFAULT_DB_ALIAS
 from django.db import DEFAULT_DB_ALIAS
 from django.db.models.signals import m2m_changed, post_save, pre_delete
 from django.db.models.signals import m2m_changed, post_save, pre_delete
+from django.dispatch import receiver
 from django.utils import timezone
 from django.utils import timezone
 from django_prometheus.models import model_deletes, model_inserts, model_updates
 from django_prometheus.models import model_deletes, model_inserts, model_updates
 from prometheus_client import Counter
 from prometheus_client import Counter
 
 
+from netbox.signals import post_clean
 from .choices import ObjectChangeActionChoices
 from .choices import ObjectChangeActionChoices
 from .models import CustomField, ObjectChange
 from .models import CustomField, ObjectChange
 from .webhooks import enqueue_object, get_snapshots, serialize_for_webhook
 from .webhooks import enqueue_object, get_snapshots, serialize_for_webhook
@@ -136,6 +138,18 @@ post_save.connect(handle_cf_renamed, sender=CustomField)
 pre_delete.connect(handle_cf_deleted, sender=CustomField)
 pre_delete.connect(handle_cf_deleted, sender=CustomField)
 
 
 
 
+#
+# Custom validation
+#
+
+@receiver(post_clean)
+def run_custom_validators(sender, instance, **kwargs):
+    model_name = f'{sender._meta.app_label}.{sender._meta.model_name}'
+    validators = settings.CUSTOM_VALIDATORS.get(model_name, [])
+    for validator in validators:
+        validator(instance)
+
+
 #
 #
 # Caching
 # Caching
 #
 #

+ 75 - 0
netbox/extras/tests/test_customvalidator.py

@@ -0,0 +1,75 @@
+from django.conf import settings
+from django.core.exceptions import ValidationError
+from django.test import TestCase, override_settings
+
+from dcim.models import Site
+from extras.validators import CustomValidator
+
+
+class MyValidator(CustomValidator):
+
+    def validate(self, instance):
+        if instance.name != 'foo':
+            self.fail("Name must be foo!")
+
+
+stock_validator = CustomValidator({
+    'name': {
+        'min_length': 5,
+        'max_length': 10,
+        'regex': r'\d{3}$',  # Ends with three digits
+    },
+    'asn': {
+        'min': 65000,
+        'max': 65100,
+    }
+})
+
+custom_validator = MyValidator()
+
+
+class CustomValidatorTest(TestCase):
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_configuration(self):
+        self.assertIn('dcim.site', settings.CUSTOM_VALIDATORS)
+        validator = settings.CUSTOM_VALIDATORS['dcim.site'][0]
+        self.assertIsInstance(validator, CustomValidator)
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_min(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abcdef123', slug='abcdefghijk', asn=1).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_max(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abcdef123', slug='abcdefghijk', asn=65535).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_min_length(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abc', slug='abc', asn=65000).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_max_length(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abcdefghijk', slug='abcdefghijk', asn=65000).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_regex(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abcdefgh', slug='abcdefgh', asn=65000).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [stock_validator]})
+    def test_valid(self):
+        Site(name='abcdef123', slug='abcdef123', asn=65000).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [custom_validator]})
+    def test_custom_invalid(self):
+        with self.assertRaises(ValidationError):
+            Site(name='abc', slug='abc', asn=65000).clean()
+
+    @override_settings(CUSTOM_VALIDATORS={'dcim.site': [custom_validator]})
+    def test_custom_valid(self):
+        Site(name='foo', slug='foo', asn=65000).clean()

+ 72 - 0
netbox/extras/validators.py

@@ -0,0 +1,72 @@
+from django.core.exceptions import ValidationError
+from django.core import validators
+
+
+class CustomValidator:
+    """
+    This class enables the application of user-defined validation rules to NetBox models. It can be instantiated by
+    passing a dictionary of validation rules in the form {attribute: rules}, where 'rules' is a dictionary mapping
+    descriptors (e.g. min_length or regex) to values.
+
+    A CustomValidator instance is applied by calling it with the instance being validated:
+
+        validator = CustomValidator({'name': {'min_length: 10}})
+        site = Site(name='abcdef')
+        validator(site)  # Raises ValidationError
+
+    :param validation_rules: A dictionary mapping object attributes to validation rules
+    """
+    VALIDATORS = {
+        'min': validators.MinValueValidator,
+        'max': validators.MaxValueValidator,
+        'min_length': validators.MinLengthValidator,
+        'max_length': validators.MaxLengthValidator,
+        'regex': validators.RegexValidator,
+    }
+
+    def __init__(self, validation_rules=None):
+        self.validation_rules = validation_rules or {}
+        assert type(self.validation_rules) is dict, "Validation rules must be passed as a dictionary"
+
+    def __call__(self, instance):
+        # Validate instance attributes per validation rules
+        for attr_name, rules in self.validation_rules.items():
+            assert hasattr(instance, attr_name), f"Invalid attribute '{attr_name}' for {instance.__class__.__name__}"
+            attr = getattr(instance, attr_name)
+            for descriptor, value in rules.items():
+                validator = self.get_validator(descriptor, value)
+                try:
+                    validator(attr)
+                except ValidationError as exc:
+                    # Re-package the raised ValidationError to associate it with the specific attr
+                    raise ValidationError({attr_name: exc})
+
+        # Execute custom validation logic (if any)
+        self.validate(instance)
+
+    def get_validator(self, descriptor, value):
+        """
+        Instantiate and return the appropriate validator based on the descriptor given. For
+        example, 'min' returns MinValueValidator(value).
+        """
+        if descriptor not in self.VALIDATORS:
+            raise NotImplementedError(
+                f"Unknown validation type for {self.__class__.__name__}: '{descriptor}'"
+            )
+        validator_cls = self.VALIDATORS.get(descriptor)
+        return validator_cls(value)
+
+    def validate(self, instance):
+        """
+        Custom validation method, to be overridden by the user. Validation failures should
+        raise a ValidationError exception.
+        """
+        return
+
+    def fail(self, message, attr=None):
+        """
+        Raise a ValidationError exception. Associate the provided message with an attribute if specified.
+        """
+        if attr is not None:
+            raise ValidationError({attr: message})
+        raise ValidationError(message)

+ 14 - 0
netbox/netbox/configuration.example.py

@@ -106,6 +106,20 @@ CORS_ORIGIN_REGEX_WHITELIST = [
     # r'^(https?://)?(\w+\.)?example\.com$',
     # r'^(https?://)?(\w+\.)?example\.com$',
 ]
 ]
 
 
+# Specify any custom validators here, as a mapping of model to a list of validators classes. Validators should be
+# instances of or inherit from CustomValidator.
+# from extras.validators import CustomValidator
+CUSTOM_VALIDATORS = {
+    # 'dcim.site': [
+    #     CustomValidator({
+    #         'name': {
+    #             'min_length': 10,
+    #             'regex': r'\d{3}$',
+    #         }
+    #     })
+    # ],
+}
+
 # Set to True to enable server debugging. WARNING: Debugging introduces a substantial performance penalty and may reveal
 # Set to True to enable server debugging. WARNING: Debugging introduces a substantial performance penalty and may reveal
 # sensitive information about your installation. Only enable debugging while performing testing. Never enable debugging
 # sensitive information about your installation. Only enable debugging while performing testing. Never enable debugging
 # on a production system.
 # on a production system.

+ 19 - 4
netbox/netbox/models.py

@@ -9,6 +9,7 @@ from mptt.models import MPTTModel, TreeForeignKey
 from taggit.managers import TaggableManager
 from taggit.managers import TaggableManager
 
 
 from extras.choices import ObjectChangeActionChoices
 from extras.choices import ObjectChangeActionChoices
+from netbox.signals import post_clean
 from utilities.mptt import TreeManager
 from utilities.mptt import TreeManager
 from utilities.utils import serialize_object
 from utilities.utils import serialize_object
 
 
@@ -123,6 +124,20 @@ class CustomFieldsMixin(models.Model):
                 raise ValidationError(f"Missing required custom field '{cf.name}'.")
                 raise ValidationError(f"Missing required custom field '{cf.name}'.")
 
 
 
 
+class CustomValidationMixin(models.Model):
+    """
+    Enables user-configured validation rules for built-in models by extending the clean() method.
+    """
+    class Meta:
+        abstract = True
+
+    def clean(self):
+        super().clean()
+
+        # Send the post_clean signal
+        post_clean.send(sender=self.__class__, instance=self)
+
+
 #
 #
 # Base model classes
 # Base model classes
 
 
@@ -138,7 +153,7 @@ class BigIDModel(models.Model):
         abstract = True
         abstract = True
 
 
 
 
-class ChangeLoggedModel(ChangeLoggingMixin, BigIDModel):
+class ChangeLoggedModel(ChangeLoggingMixin, CustomValidationMixin, BigIDModel):
     """
     """
     Base model for all objects which support change logging.
     Base model for all objects which support change logging.
     """
     """
@@ -146,7 +161,7 @@ class ChangeLoggedModel(ChangeLoggingMixin, BigIDModel):
         abstract = True
         abstract = True
 
 
 
 
-class PrimaryModel(ChangeLoggingMixin, CustomFieldsMixin, BigIDModel):
+class PrimaryModel(ChangeLoggingMixin, CustomFieldsMixin, CustomValidationMixin, BigIDModel):
     """
     """
     Primary models represent real objects within the infrastructure being modeled.
     Primary models represent real objects within the infrastructure being modeled.
     """
     """
@@ -163,7 +178,7 @@ class PrimaryModel(ChangeLoggingMixin, CustomFieldsMixin, BigIDModel):
         abstract = True
         abstract = True
 
 
 
 
-class NestedGroupModel(ChangeLoggingMixin, CustomFieldsMixin, BigIDModel, MPTTModel):
+class NestedGroupModel(ChangeLoggingMixin, CustomFieldsMixin, CustomValidationMixin, BigIDModel, MPTTModel):
     """
     """
     Base model for objects which are used to form a hierarchy (regions, locations, etc.). These models nest
     Base model for objects which are used to form a hierarchy (regions, locations, etc.). These models nest
     recursively using MPTT. Within each parent, each child instance must have a unique name.
     recursively using MPTT. Within each parent, each child instance must have a unique name.
@@ -205,7 +220,7 @@ class NestedGroupModel(ChangeLoggingMixin, CustomFieldsMixin, BigIDModel, MPTTMo
             })
             })
 
 
 
 
-class OrganizationalModel(ChangeLoggingMixin, CustomFieldsMixin, BigIDModel):
+class OrganizationalModel(ChangeLoggingMixin, CustomFieldsMixin, CustomValidationMixin, BigIDModel):
     """
     """
     Organizational models are those which are used solely to categorize and qualify other objects, and do not convey
     Organizational models are those which are used solely to categorize and qualify other objects, and do not convey
     any real information about the infrastructure being modeled (for example, functional device roles). Organizational
     any real information about the infrastructure being modeled (for example, functional device roles). Organizational

+ 1 - 0
netbox/netbox/settings.py

@@ -74,6 +74,7 @@ CHANGELOG_RETENTION = getattr(configuration, 'CHANGELOG_RETENTION', 90)
 CORS_ORIGIN_ALLOW_ALL = getattr(configuration, 'CORS_ORIGIN_ALLOW_ALL', False)
 CORS_ORIGIN_ALLOW_ALL = getattr(configuration, 'CORS_ORIGIN_ALLOW_ALL', False)
 CORS_ORIGIN_REGEX_WHITELIST = getattr(configuration, 'CORS_ORIGIN_REGEX_WHITELIST', [])
 CORS_ORIGIN_REGEX_WHITELIST = getattr(configuration, 'CORS_ORIGIN_REGEX_WHITELIST', [])
 CORS_ORIGIN_WHITELIST = getattr(configuration, 'CORS_ORIGIN_WHITELIST', [])
 CORS_ORIGIN_WHITELIST = getattr(configuration, 'CORS_ORIGIN_WHITELIST', [])
+CUSTOM_VALIDATORS = getattr(configuration, 'CUSTOM_VALIDATORS', {})
 DATE_FORMAT = getattr(configuration, 'DATE_FORMAT', 'N j, Y')
 DATE_FORMAT = getattr(configuration, 'DATE_FORMAT', 'N j, Y')
 DATETIME_FORMAT = getattr(configuration, 'DATETIME_FORMAT', 'N j, Y g:i a')
 DATETIME_FORMAT = getattr(configuration, 'DATETIME_FORMAT', 'N j, Y g:i a')
 DEBUG = getattr(configuration, 'DEBUG', False)
 DEBUG = getattr(configuration, 'DEBUG', False)

+ 5 - 0
netbox/netbox/signals.py

@@ -0,0 +1,5 @@
+from django.dispatch import Signal
+
+
+# Signals that a model has completed its clean() method
+post_clean = Signal()