Просмотр исходного кода

Merge pull request #3654 from netbox-community/3538-scripts-api

3538: Add custom script API endpoints
Jeremy Stretch 6 лет назад
Родитель
Сommit
56a248e601

+ 46 - 0
netbox/extras/api/serializers.py

@@ -200,6 +200,52 @@ class ReportDetailSerializer(ReportSerializer):
     result = ReportResultSerializer()
 
 
+#
+# Scripts
+#
+
+class ScriptSerializer(serializers.Serializer):
+    id = serializers.SerializerMethodField(read_only=True)
+    name = serializers.SerializerMethodField(read_only=True)
+    description = serializers.SerializerMethodField(read_only=True)
+    vars = serializers.SerializerMethodField(read_only=True)
+
+    def get_id(self, instance):
+        return '{}.{}'.format(instance.__module__, instance.__name__)
+
+    def get_name(self, instance):
+        return getattr(instance.Meta, 'name', instance.__name__)
+
+    def get_description(self, instance):
+        return getattr(instance.Meta, 'description', '')
+
+    def get_vars(self, instance):
+        return {
+            k: v.__class__.__name__ for k, v in instance._get_vars().items()
+        }
+
+
+class ScriptInputSerializer(serializers.Serializer):
+    data = serializers.JSONField()
+    commit = serializers.BooleanField()
+
+
+class ScriptLogMessageSerializer(serializers.Serializer):
+    status = serializers.SerializerMethodField(read_only=True)
+    message = serializers.SerializerMethodField(read_only=True)
+
+    def get_status(self, instance):
+        return LOG_LEVEL_CODES.get(instance[0])
+
+    def get_message(self, instance):
+        return instance[1]
+
+
+class ScriptOutputSerializer(serializers.Serializer):
+    log = ScriptLogMessageSerializer(many=True, read_only=True)
+    output = serializers.CharField(read_only=True)
+
+
 #
 # Change logging
 #

+ 3 - 0
netbox/extras/api/urls.py

@@ -38,6 +38,9 @@ router.register(r'config-contexts', views.ConfigContextViewSet)
 # Reports
 router.register(r'reports', views.ReportViewSet, basename='report')
 
+# Scripts
+router.register(r'scripts', views.ScriptViewSet, basename='script')
+
 # Change logging
 router.register(r'object-changes', views.ObjectChangeViewSet)
 

+ 52 - 0
netbox/extras/api/views.py

@@ -3,6 +3,7 @@ from collections import OrderedDict
 from django.contrib.contenttypes.models import ContentType
 from django.db.models import Count
 from django.http import Http404
+from rest_framework import status
 from rest_framework.decorators import action
 from rest_framework.exceptions import PermissionDenied
 from rest_framework.response import Response
@@ -13,6 +14,7 @@ from extras.models import (
     ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag,
 )
 from extras.reports import get_report, get_reports
+from extras.scripts import get_script, get_scripts
 from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet
 from . import serializers
 
@@ -222,6 +224,56 @@ class ReportViewSet(ViewSet):
         return Response(serializer.data)
 
 
+#
+# Scripts
+#
+
+class ScriptViewSet(ViewSet):
+    permission_classes = [IsAuthenticatedOrLoginNotRequired]
+    _ignore_model_permissions = True
+    exclude_from_schema = True
+    lookup_value_regex = '[^/]+'  # Allow dots
+
+    def _get_script(self, pk):
+        module_name, script_name = pk.split('.')
+        script = get_script(module_name, script_name)
+        if script is None:
+            raise Http404
+        return script
+
+    def list(self, request):
+
+        flat_list = []
+        for script_list in get_scripts().values():
+            flat_list.extend(script_list.values())
+
+        serializer = serializers.ScriptSerializer(flat_list, many=True, context={'request': request})
+
+        return Response(serializer.data)
+
+    def retrieve(self, request, pk):
+        script = self._get_script(pk)
+        serializer = serializers.ScriptSerializer(script, context={'request': request})
+
+        return Response(serializer.data)
+
+    def post(self, request, pk):
+        """
+        Run a Script identified as "<module>.<script>".
+        """
+        script = self._get_script(pk)()
+        input_serializer = serializers.ScriptInputSerializer(data=request.data)
+
+        if input_serializer.is_valid():
+            output = script.run(input_serializer.data['data'])
+            script.output = output
+            output_serializer = serializers.ScriptOutputSerializer(script)
+
+            return Response(output_serializer.data)
+
+        return Response(input_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
 #
 # Change logging
 #

+ 25 - 6
netbox/extras/scripts.py

@@ -220,16 +220,21 @@ class BaseScript:
     def __str__(self):
         return getattr(self.Meta, 'name', self.__class__.__name__)
 
-    def _get_vars(self):
+    @classmethod
+    def module(cls):
+        return cls.__module__
+
+    @classmethod
+    def _get_vars(cls):
         vars = OrderedDict()
 
         # Infer order from Meta.field_order (Python 3.5 and lower)
-        field_order = getattr(self.Meta, 'field_order', [])
+        field_order = getattr(cls.Meta, 'field_order', [])
         for name in field_order:
-            vars[name] = getattr(self, name)
+            vars[name] = getattr(cls, name)
 
         # Default to order of declaration on class
-        for name, attr in self.__class__.__dict__.items():
+        for name, attr in cls.__dict__.items():
             if name not in vars and issubclass(attr.__class__, ScriptVariable):
                 vars[name] = attr
 
@@ -360,14 +365,18 @@ def run_script(script, data, files, commit=True):
     return output, execution_time
 
 
-def get_scripts():
+def get_scripts(use_names=False):
+    """
+    Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
+    defined name in place of the actual module name.
+    """
     scripts = OrderedDict()
 
     # Iterate through all modules within the reports path. These are the user-created files in which reports are
     # defined.
     for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
         module = importer.find_module(module_name).load_module(module_name)
-        if hasattr(module, 'name'):
+        if use_names and hasattr(module, 'name'):
             module_name = module.name
         module_scripts = OrderedDict()
         for name, cls in inspect.getmembers(module, is_script):
@@ -375,3 +384,13 @@ def get_scripts():
         scripts[module_name] = module_scripts
 
     return scripts
+
+
+def get_script(module_name, script_name):
+    """
+    Retrieve a script class by module and name. Returns None if the script does not exist.
+    """
+    scripts = get_scripts()
+    module = scripts.get(module_name)
+    if module:
+        return module.get(script_name)

+ 67 - 0
netbox/extras/tests/test_api.py

@@ -3,8 +3,10 @@ from django.urls import reverse
 from rest_framework import status
 
 from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site
+from extras.api.views import ScriptViewSet
 from extras.constants import GRAPH_TYPE_SITE
 from extras.models import ConfigContext, Graph, ExportTemplate, Tag
+from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
 from tenancy.models import Tenant, TenantGroup
 from utilities.testing import APITestCase
 
@@ -520,3 +522,68 @@ class ConfigContextTest(APITestCase):
         configcontext6.sites.add(site2)
         rendered_context = device.get_config_context()
         self.assertEqual(rendered_context['bar'], 456)
+
+
+class ScriptTest(APITestCase):
+
+    class TestScript(Script):
+
+        class Meta:
+            name = "Test script"
+
+        var1 = StringVar()
+        var2 = IntegerVar()
+        var3 = BooleanVar()
+
+        def run(self, data):
+
+            self.log_info(data['var1'])
+            self.log_success(data['var2'])
+            self.log_failure(data['var3'])
+
+            return 'Script complete'
+
+    def get_test_script(self, *args):
+        return self.TestScript
+
+    def setUp(self):
+
+        super().setUp()
+
+        # Monkey-patch the API viewset's _get_script method to return our test script above
+        ScriptViewSet._get_script = self.get_test_script
+
+    def test_get_script(self):
+
+        url = reverse('extras-api:script-detail', kwargs={'pk': None})
+        response = self.client.get(url, **self.header)
+
+        self.assertEqual(response.data['name'], self.TestScript.Meta.name)
+        self.assertEqual(response.data['vars']['var1'], 'StringVar')
+        self.assertEqual(response.data['vars']['var2'], 'IntegerVar')
+        self.assertEqual(response.data['vars']['var3'], 'BooleanVar')
+
+    def test_run_script(self):
+
+        script_data = {
+            'var1': 'FooBar',
+            'var2': 123,
+            'var3': False,
+        }
+
+        data = {
+            'data': script_data,
+            'commit': True,
+        }
+
+        url = reverse('extras-api:script-detail', kwargs={'pk': None})
+        response = self.client.post(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        self.assertEqual(response.data['log'][0]['status'], 'info')
+        self.assertEqual(response.data['log'][0]['message'], script_data['var1'])
+        self.assertEqual(response.data['log'][1]['status'], 'success')
+        self.assertEqual(response.data['log'][1]['message'], script_data['var2'])
+        self.assertEqual(response.data['log'][2]['status'], 'failure')
+        self.assertEqual(response.data['log'][2]['message'], script_data['var3'])
+        self.assertEqual(response.data['output'], 'Script complete')

+ 1 - 1
netbox/extras/views.py

@@ -375,7 +375,7 @@ class ScriptListView(PermissionRequiredMixin, View):
     def get(self, request):
 
         return render(request, 'extras/script_list.html', {
-            'scripts': get_scripts(),
+            'scripts': get_scripts(use_names=True),
         })
 
 

+ 1 - 1
netbox/templates/extras/script_list.html

@@ -19,7 +19,7 @@
                             {% for class_name, script in module_scripts.items %}
                                 <tr>
                                     <td>
-                                        <a href="{% url 'extras:script' module=module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
+                                        <a href="{% url 'extras:script' module=script.module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
                                     </td>
                                     <td>{{ script.Meta.description }}</td>
                                 </tr>