Просмотр исходного кода

Fixes #7628: Fix load_yaml method for custom scripts

jeremystretch 4 лет назад
Родитель
Сommit
b56cae24c5
3 измененных файлов с 53 добавлено и 2 удалено
  1. 1 0
      docs/release-notes/version-3.0.md
  2. 6 2
      netbox/extras/scripts.py
  3. 46 0
      netbox/extras/tests/test_scripts.py

+ 1 - 0
docs/release-notes/version-3.0.md

@@ -5,6 +5,7 @@
 ### Bug Fixes
 
 * [#7612](https://github.com/netbox-community/netbox/issues/7612) - Strip HTML from custom field descriptions
+* [#7628](https://github.com/netbox-community/netbox/issues/7628) - Fix `load_yaml` method for custom scripts
 
 ---
 

+ 6 - 2
netbox/extras/scripts.py

@@ -4,7 +4,6 @@ import logging
 import os
 import pkgutil
 import traceback
-import warnings
 from collections import OrderedDict
 
 import yaml
@@ -345,9 +344,14 @@ class BaseScript:
         """
         Return data from a YAML file
         """
+        try:
+            from yaml import CLoader as Loader
+        except ImportError:
+            from yaml import Loader
+
         file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
         with open(file_path, 'r') as datafile:
-            data = yaml.load(datafile)
+            data = yaml.load(datafile, Loader=Loader)
 
         return data
 

+ 46 - 0
netbox/extras/tests/test_scripts.py

@@ -1,3 +1,5 @@
+import tempfile
+
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from netaddr import IPAddress, IPNetwork
@@ -11,6 +13,50 @@ CHOICES = (
     ('0000ff', 'Blue')
 )
 
+YAML_DATA = """
+Foo: 123
+Bar: 456
+Baz:
+ - A
+ - B
+ - C
+"""
+
+JSON_DATA = """
+{
+  "Foo": 123,
+  "Bar": 456,
+  "Baz": ["A", "B", "C"]
+}
+"""
+
+
+class ScriptTest(TestCase):
+
+    def test_load_yaml(self):
+        datafile = tempfile.NamedTemporaryFile()
+        datafile.write(bytes(YAML_DATA, 'UTF-8'))
+        datafile.seek(0)
+
+        data = Script().load_yaml(datafile.name)
+        self.assertEqual(data, {
+            'Foo': 123,
+            'Bar': 456,
+            'Baz': ['A', 'B', 'C'],
+        })
+
+    def test_load_json(self):
+        datafile = tempfile.NamedTemporaryFile()
+        datafile.write(bytes(JSON_DATA, 'UTF-8'))
+        datafile.seek(0)
+
+        data = Script().load_json(datafile.name)
+        self.assertEqual(data, {
+            'Foo': 123,
+            'Bar': 456,
+            'Baz': ['A', 'B', 'C'],
+        })
+
 
 class ScriptVariablesTest(TestCase):