Просмотр исходного кода

Change the way we invalidate the module cache to support reloading code from subpackages

kkthxbye-code 3 лет назад
Родитель
Сommit
c78022a74c
1 измененных файлов с 19 добавлено и 7 удалено
  1. 19 7
      netbox/extras/scripts.py

+ 19 - 7
netbox/extras/scripts.py

@@ -524,27 +524,39 @@ def get_scripts(use_names=False):
     defined name in place of the actual module name.
     """
     scripts = {}
-    # Iterate through all modules within the scripts path. These are the user-created files in which reports are
+
+    # Get all modules within the scripts path. These are the user-created files in which scripts are
     # defined.
-    for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
-        # Use a lock as removing and loading modules is not thread safe
-        with lock:
-            # Remove cached module to ensure consistency with filesystem
-            if module_name in sys.modules:
+    modules = list(pkgutil.iter_modules([settings.SCRIPTS_ROOT]))
+    modules_bases = set([name.split(".")[0] for _, name, _ in modules])
+
+    # Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
+    # removed from sys.modules while another thread is importing
+    with lock:
+        for module_name in list(sys.modules.keys()):
+            # Everything sharing a base module path with a module in the script folder is removed.
+            # We also remove all modules with a base module called "scripts". This allows modifying imported
+            # non-script modules without having to reload the RQ worker.
+            module_base = module_name.split(".")[0]
+            if module_base == "scripts" or module_base in modules_bases:
                 del sys.modules[module_name]
 
-            module = importer.find_module(module_name).load_module(module_name)
+    for importer, module_name, _ in modules:
+        module = importer.find_module(module_name).load_module(module_name)
 
         if use_names and hasattr(module, 'name'):
             module_name = module.name
+
         module_scripts = {}
         script_order = getattr(module, "script_order", ())
         ordered_scripts = [cls for cls in script_order if is_script(cls)]
         unordered_scripts = [cls for _, cls in inspect.getmembers(module, is_script) if cls not in script_order]
+
         for cls in [*ordered_scripts, *unordered_scripts]:
             # For scripts in submodules use the full import path w/o the root module as the name
             script_name = cls.full_name.split(".", maxsplit=1)[1]
             module_scripts[script_name] = cls
+
         if module_scripts:
             scripts[module_name] = module_scripts