Просмотр исходного кода

Merge pull request #3988 from netbox-community/3509-ipaddress-script-vars

Closes #3509: Add IP address vars for custom scripts
Jeremy Stretch 6 лет назад
Родитель
Сommit
2445d1896b

+ 13 - 2
docs/additional-features/custom-scripts.md

@@ -124,7 +124,7 @@ Arbitrary text of any length. Renders as multi-line text input field.
 
 
 Stored a numeric integer. Options include:
 Stored a numeric integer. Options include:
 
 
-* `min_value:` - Minimum value
+* `min_value` - Minimum value
 * `max_value` - Maximum value
 * `max_value` - Maximum value
 
 
 ### BooleanVar
 ### BooleanVar
@@ -158,9 +158,20 @@ A NetBox object. The list of available objects is defined by the queryset parame
 
 
 An uploaded file. Note that uploaded files are present in memory only for the duration of the script's execution: They will not be save for future use.
 An uploaded file. Note that uploaded files are present in memory only for the duration of the script's execution: They will not be save for future use.
 
 
+### IPAddressVar
+
+An IPv4 or IPv6 address, without a mask. Returns a `netaddr.IPAddress` object.
+
+### IPAddressWithMaskVar
+
+An IPv4 or IPv6 address with a mask. Returns a `netaddr.IPNetwork` object which includes the mask.
+
 ### IPNetworkVar
 ### IPNetworkVar
 
 
-An IPv4 or IPv6 network with a mask.
+An IPv4 or IPv6 network with a mask. Returns a `netaddr.IPNetwork` object. Two attributes are available to validate the provided mask:
+
+* `min_prefix_length` - Minimum length of the mask (default: none)
+* `max_prefix_length` - Maximum length of the mask (default: none)
 
 
 ### Default Options
 ### Default Options
 
 

+ 1 - 0
docs/release-notes/version-2.7.md

@@ -3,6 +3,7 @@
 ## Enhancements
 ## Enhancements
 
 
 * [#3310](https://github.com/netbox-community/netbox/issues/3310) - Pre-select site/rack for B side when creating a new cable
 * [#3310](https://github.com/netbox-community/netbox/issues/3310) - Pre-select site/rack for B side when creating a new cable
+* [#3509](https://github.com/netbox-community/netbox/issues/3509) - Add IP address variables for custom scripts
 
 
 ## Bug Fixes
 ## Bug Fixes
 
 

+ 32 - 11
netbox/extras/scripts.py

@@ -14,10 +14,10 @@ from django.db import transaction
 from mptt.forms import TreeNodeChoiceField, TreeNodeMultipleChoiceField
 from mptt.forms import TreeNodeChoiceField, TreeNodeMultipleChoiceField
 from mptt.models import MPTTModel
 from mptt.models import MPTTModel
 
 
-from ipam.formfields import IPFormField
-from utilities.exceptions import AbortTransaction
-from utilities.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator
+from ipam.formfields import IPAddressFormField, IPNetworkFormField
+from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator
 from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING
 from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING
+from utilities.exceptions import AbortTransaction
 from .forms import ScriptForm
 from .forms import ScriptForm
 from .signals import purge_changelog
 from .signals import purge_changelog
 
 
@@ -27,6 +27,8 @@ __all__ = [
     'ChoiceVar',
     'ChoiceVar',
     'FileVar',
     'FileVar',
     'IntegerVar',
     'IntegerVar',
+    'IPAddressVar',
+    'IPAddressWithMaskVar',
     'IPNetworkVar',
     'IPNetworkVar',
     'MultiObjectVar',
     'MultiObjectVar',
     'ObjectVar',
     'ObjectVar',
@@ -48,15 +50,19 @@ class ScriptVariable:
 
 
     def __init__(self, label='', description='', default=None, required=True):
     def __init__(self, label='', description='', default=None, required=True):
 
 
-        # Default field attributes
-        self.field_attrs = {
-            'help_text': description,
-            'required': required
-        }
+        # Initialize field attributes
+        if not hasattr(self, 'field_attrs'):
+            self.field_attrs = {}
+        if description:
+            self.field_attrs['help_text'] = description
         if label:
         if label:
             self.field_attrs['label'] = label
             self.field_attrs['label'] = label
         if default:
         if default:
             self.field_attrs['initial'] = default
             self.field_attrs['initial'] = default
+        if required:
+            self.field_attrs['required'] = True
+        if 'validators' not in self.field_attrs:
+            self.field_attrs['validators'] = []
 
 
     def as_field(self):
     def as_field(self):
         """
         """
@@ -196,17 +202,32 @@ class FileVar(ScriptVariable):
     form_field = forms.FileField
     form_field = forms.FileField
 
 
 
 
+class IPAddressVar(ScriptVariable):
+    """
+    An IPv4 or IPv6 address without a mask.
+    """
+    form_field = IPAddressFormField
+
+
+class IPAddressWithMaskVar(ScriptVariable):
+    """
+    An IPv4 or IPv6 address with a mask.
+    """
+    form_field = IPNetworkFormField
+
+
 class IPNetworkVar(ScriptVariable):
 class IPNetworkVar(ScriptVariable):
     """
     """
     An IPv4 or IPv6 prefix.
     An IPv4 or IPv6 prefix.
     """
     """
-    form_field = IPFormField
+    form_field = IPNetworkFormField
+    field_attrs = {
+        'validators': [prefix_validator]
+    }
 
 
     def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs):
     def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
-        self.field_attrs['validators'] = list()
-
         # Optional minimum/maximum prefix lengths
         # Optional minimum/maximum prefix lengths
         if min_prefix_length is not None:
         if min_prefix_length is not None:
             self.field_attrs['validators'].append(
             self.field_attrs['validators'].append(

+ 55 - 1
netbox/extras/tests/test_scripts.py

@@ -1,6 +1,6 @@
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test import TestCase
-from netaddr import IPNetwork
+from netaddr import IPAddress, IPNetwork
 
 
 from dcim.models import DeviceRole
 from dcim.models import DeviceRole
 from extras.scripts import *
 from extras.scripts import *
@@ -186,6 +186,54 @@ class ScriptVariablesTest(TestCase):
         self.assertTrue(form.is_valid())
         self.assertTrue(form.is_valid())
         self.assertEqual(form.cleaned_data['var1'], testfile)
         self.assertEqual(form.cleaned_data['var1'], testfile)
 
 
+    def test_ipaddressvar(self):
+
+        class TestScript(Script):
+
+            var1 = IPAddressVar()
+
+        # Validate IP network enforcement
+        data = {'var1': '1.2.3'}
+        form = TestScript().as_form(data, None)
+        self.assertFalse(form.is_valid())
+        self.assertIn('var1', form.errors)
+
+        # Validate IP mask exclusion
+        data = {'var1': '192.0.2.0/24'}
+        form = TestScript().as_form(data, None)
+        self.assertFalse(form.is_valid())
+        self.assertIn('var1', form.errors)
+
+        # Validate valid data
+        data = {'var1': '192.0.2.1'}
+        form = TestScript().as_form(data, None)
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['var1'], IPAddress(data['var1']))
+
+    def test_ipaddresswithmaskvar(self):
+
+        class TestScript(Script):
+
+            var1 = IPAddressWithMaskVar()
+
+        # Validate IP network enforcement
+        data = {'var1': '1.2.3'}
+        form = TestScript().as_form(data, None)
+        self.assertFalse(form.is_valid())
+        self.assertIn('var1', form.errors)
+
+        # Validate IP mask requirement
+        data = {'var1': '192.0.2.0'}
+        form = TestScript().as_form(data, None)
+        self.assertFalse(form.is_valid())
+        self.assertIn('var1', form.errors)
+
+        # Validate valid data
+        data = {'var1': '192.0.2.0/24'}
+        form = TestScript().as_form(data, None)
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['var1'], IPNetwork(data['var1']))
+
     def test_ipnetworkvar(self):
     def test_ipnetworkvar(self):
 
 
         class TestScript(Script):
         class TestScript(Script):
@@ -198,6 +246,12 @@ class ScriptVariablesTest(TestCase):
         self.assertFalse(form.is_valid())
         self.assertFalse(form.is_valid())
         self.assertIn('var1', form.errors)
         self.assertIn('var1', form.errors)
 
 
+        # Validate host IP check
+        data = {'var1': '192.0.2.1/24'}
+        form = TestScript().as_form(data, None)
+        self.assertFalse(form.is_valid())
+        self.assertIn('var1', form.errors)
+
         # Validate valid data
         # Validate valid data
         data = {'var1': '192.0.2.0/24'}
         data = {'var1': '192.0.2.0/24'}
         form = TestScript().as_form(data, None)
         form = TestScript().as_form(data, None)

+ 4 - 9
netbox/ipam/fields.py

@@ -2,13 +2,8 @@ from django.core.exceptions import ValidationError
 from django.db import models
 from django.db import models
 from netaddr import AddrFormatError, IPNetwork
 from netaddr import AddrFormatError, IPNetwork
 
 
-from . import lookups
-from .formfields import IPFormField
-
-
-def prefix_validator(prefix):
-    if prefix.ip != prefix.cidr.ip:
-        raise ValidationError("{} is not a valid prefix. Did you mean {}?".format(prefix, prefix.cidr))
+from . import lookups, validators
+from .formfields import IPNetworkFormField
 
 
 
 
 class BaseIPField(models.Field):
 class BaseIPField(models.Field):
@@ -38,7 +33,7 @@ class BaseIPField(models.Field):
         return str(self.to_python(value))
         return str(self.to_python(value))
 
 
     def form_class(self):
     def form_class(self):
-        return IPFormField
+        return IPNetworkFormField
 
 
     def formfield(self, **kwargs):
     def formfield(self, **kwargs):
         defaults = {'form_class': self.form_class()}
         defaults = {'form_class': self.form_class()}
@@ -51,7 +46,7 @@ class IPNetworkField(BaseIPField):
     IP prefix (network and mask)
     IP prefix (network and mask)
     """
     """
     description = "PostgreSQL CIDR field"
     description = "PostgreSQL CIDR field"
-    default_validators = [prefix_validator]
+    default_validators = [validators.prefix_validator]
 
 
     def db_type(self, connection):
     def db_type(self, connection):
         return 'cidr'
         return 'cidr'

+ 33 - 2
netbox/ipam/formfields.py

@@ -1,13 +1,44 @@
 from django import forms
 from django import forms
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
-from netaddr import IPNetwork, AddrFormatError
+from django.core.validators import validate_ipv4_address, validate_ipv6_address
+from netaddr import IPAddress, IPNetwork, AddrFormatError
 
 
 
 
 #
 #
 # Form fields
 # Form fields
 #
 #
 
 
-class IPFormField(forms.Field):
+class IPAddressFormField(forms.Field):
+    default_error_messages = {
+        'invalid': "Enter a valid IPv4 or IPv6 address (without a mask).",
+    }
+
+    def to_python(self, value):
+        if not value:
+            return None
+
+        if isinstance(value, IPAddress):
+            return value
+
+        # netaddr is a bit too liberal with what it accepts as a valid IP address. For example, '1.2.3' will become
+        # IPAddress('1.2.0.3'). Here, we employ Django's built-in IPv4 and IPv6 address validators as a sanity check.
+        try:
+            validate_ipv4_address(value)
+        except ValidationError:
+            try:
+                validate_ipv6_address(value)
+            except ValidationError:
+                raise ValidationError("Invalid IPv4/IPv6 address format: {}".format(value))
+
+        try:
+            return IPAddress(value)
+        except ValueError:
+            raise ValidationError('This field requires an IP address without a mask.')
+        except AddrFormatError:
+            raise ValidationError("Please specify a valid IPv4 or IPv6 address.")
+
+
+class IPNetworkFormField(forms.Field):
     default_error_messages = {
     default_error_messages = {
         'invalid': "Enter a valid IPv4 or IPv6 address (with CIDR mask).",
         'invalid': "Enter a valid IPv4 or IPv6 address (with CIDR mask).",
     }
     }

+ 23 - 1
netbox/ipam/validators.py

@@ -1,4 +1,26 @@
-from django.core.validators import RegexValidator
+from django.core.exceptions import ValidationError
+from django.core.validators import BaseValidator, RegexValidator
+
+
+def prefix_validator(prefix):
+    if prefix.ip != prefix.cidr.ip:
+        raise ValidationError("{} is not a valid prefix. Did you mean {}?".format(prefix, prefix.cidr))
+
+
+class MaxPrefixLengthValidator(BaseValidator):
+    message = 'The prefix length must be less than or equal to %(limit_value)s.'
+    code = 'max_prefix_length'
+
+    def compare(self, a, b):
+        return a.prefixlen > b
+
+
+class MinPrefixLengthValidator(BaseValidator):
+    message = 'The prefix length must be greater than or equal to %(limit_value)s.'
+    code = 'min_prefix_length'
+
+    def compare(self, a, b):
+        return a.prefixlen < b
 
 
 
 
 DNSValidator = RegexValidator(
 DNSValidator = RegexValidator(

+ 1 - 17
netbox/utilities/validators.py

@@ -1,6 +1,6 @@
 import re
 import re
 
 
-from django.core.validators import _lazy_re_compile, BaseValidator, URLValidator
+from django.core.validators import _lazy_re_compile, URLValidator
 
 
 
 
 class EnhancedURLValidator(URLValidator):
 class EnhancedURLValidator(URLValidator):
@@ -26,19 +26,3 @@ class EnhancedURLValidator(URLValidator):
         r'(?:[/?#][^\s]*)?'                 # Path
         r'(?:[/?#][^\s]*)?'                 # Path
         r'\Z', re.IGNORECASE)
         r'\Z', re.IGNORECASE)
     schemes = AnyURLScheme()
     schemes = AnyURLScheme()
-
-
-class MaxPrefixLengthValidator(BaseValidator):
-    message = 'The prefix length must be less than or equal to %(limit_value)s.'
-    code = 'max_prefix_length'
-
-    def compare(self, a, b):
-        return a.prefixlen > b
-
-
-class MinPrefixLengthValidator(BaseValidator):
-    message = 'The prefix length must be greater than or equal to %(limit_value)s.'
-    code = 'min_prefix_length'
-
-    def compare(self, a, b):
-        return a.prefixlen < b