scripts.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import inspect
  2. import json
  3. import logging
  4. import os
  5. import pkgutil
  6. import time
  7. import traceback
  8. from collections import OrderedDict
  9. import yaml
  10. from django import forms
  11. from django.conf import settings
  12. from django.core.validators import RegexValidator
  13. from django.db import transaction
  14. from mptt.forms import TreeNodeChoiceField, TreeNodeMultipleChoiceField
  15. from mptt.models import MPTTModel
  16. from ipam.formfields import IPAddressFormField, IPNetworkFormField
  17. from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator
  18. from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING
  19. from utilities.exceptions import AbortTransaction
  20. from utilities.forms import DynamicModelChoiceField, DynamicModelMultipleChoiceField
  21. from .forms import ScriptForm
  22. from .signals import purge_changelog
  # Public API of this module: the script base classes plus every concrete
  # variable type that script authors may declare on a Script subclass.
  __all__ = [
      'BaseScript',
      'BooleanVar',
      'ChoiceVar',
      'FileVar',
      'IntegerVar',
      'IPAddressVar',
      'IPAddressWithMaskVar',
      'IPNetworkVar',
      'MultiObjectVar',
      'ObjectVar',
      'Script',
      'StringVar',
      'TextVar',
  ]
  #
  # Script variables
  #


  class ScriptVariable:
      """
      Base model for script variables.

      Each variable accumulates keyword arguments in ``self.field_attrs``; they are
      passed to ``form_field`` when the variable is rendered by ``as_field()``.
      """
      form_field = forms.CharField

      def __init__(self, label='', description='', default=None, required=True, widget=None):
          """
          :param label: Field label (passed to the form field as ``label``)
          :param description: Help text (passed to the form field as ``help_text``)
          :param default: Initial value (passed as ``initial``). Any value other than
              None is honored, including falsy values such as 0, False or ''.
          :param required: Whether the field must be populated
          :param widget: Optional widget overriding the form field's default widget
          """
          # Initialize field attributes (kept if they already exist on the instance)
          if not hasattr(self, 'field_attrs'):
              self.field_attrs = {}
          if label:
              self.field_attrs['label'] = label
          if description:
              self.field_attrs['help_text'] = description
          # Compare against None rather than truthiness so that legitimate falsy
          # defaults (e.g. 0 for an IntegerVar) are not silently discarded.
          if default is not None:
              self.field_attrs['initial'] = default
          if widget:
              self.field_attrs['widget'] = widget
          self.field_attrs['required'] = required

      def as_field(self):
          """
          Render the variable as a Django form field.

          For every widget except CheckboxInput, ``form-control`` is appended to the
          widget's existing ``class`` attribute, or set as the class if none exists.
          """
          form_field = self.form_field(**self.field_attrs)
          if not isinstance(form_field.widget, forms.CheckboxInput):
              widget_attrs = form_field.widget.attrs
              if widget_attrs and 'class' in widget_attrs:
                  widget_attrs['class'] += ' form-control'
              else:
                  widget_attrs['class'] = 'form-control'
          return form_field
  class StringVar(ScriptVariable):
      """
      Character string representation. Can enforce minimum/maximum length and/or regex validation.
      """
      def __init__(self, min_length=None, max_length=None, regex=None, *args, **kwargs):
          """
          :param min_length: Optional minimum length, passed to the form field as ``min_length``
          :param max_length: Optional maximum length, passed to the form field as ``max_length``
          :param regex: Optional regular expression the value must match
          """
          super().__init__(*args, **kwargs)

          # Optional minimum/maximum lengths. Test against None (as IPNetworkVar does
          # for its prefix lengths) so an explicit 0 is still passed to the form field.
          if min_length is not None:
              self.field_attrs['min_length'] = min_length
          if max_length is not None:
              self.field_attrs['max_length'] = max_length

          # Optional regular expression validation
          if regex:
              self.field_attrs['validators'] = [
                  RegexValidator(
                      regex=regex,
                      message='Invalid value. Must match regex: {}'.format(regex),
                      code='invalid'
                  )
              ]
  class TextVar(ScriptVariable):
      """
      Free-form, multi-line text; rendered with a Textarea widget.
      """
      form_field = forms.CharField

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Always use a textarea, replacing any widget supplied by the caller
          self.field_attrs.update(widget=forms.Textarea)
  class IntegerVar(ScriptVariable):
      """
      Integer representation. Can enforce minimum/maximum values.
      """
      form_field = forms.IntegerField

      def __init__(self, min_value=None, max_value=None, *args, **kwargs):
          """
          :param min_value: Optional lower bound, passed to the form field as ``min_value``
          :param max_value: Optional upper bound, passed to the form field as ``max_value``
          """
          super().__init__(*args, **kwargs)

          # Optional minimum/maximum values. Compare against None: a bound of 0 is a
          # valid (and common) limit and must not be dropped by a truthiness test.
          if min_value is not None:
              self.field_attrs['min_value'] = min_value
          if max_value is not None:
              self.field_attrs['max_value'] = max_value
  class BooleanVar(ScriptVariable):
      """
      A true/false value, rendered as a checkbox.
      """
      form_field = forms.BooleanField

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Boolean fields cannot be required, so override whatever the caller passed
          self.field_attrs.update(required=False)
  class ChoiceVar(ScriptVariable):
      """
      Pick exactly one value from a fixed set of choices, given as a sequence of
      (value, label) two-tuples. Example:

          color = ChoiceVar(
              choices=(
                  ('#ff0000', 'Red'),
                  ('#00ff00', 'Green'),
                  ('#0000ff', 'Blue')
              )
          )
      """
      form_field = forms.ChoiceField

      def __init__(self, choices, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Hand the choices straight through to the ChoiceField
          self.field_attrs.update(choices=choices)
  class ObjectVar(ScriptVariable):
      """
      A single NetBox object; the given QuerySet supplies the selectable choices.
      """
      form_field = DynamicModelChoiceField

      def __init__(self, queryset, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.field_attrs['queryset'] = queryset

          # Nested (MPTT) models get a tree-aware field; this is stored on the
          # instance, shadowing the class-level default.
          is_tree_model = issubclass(queryset.model, MPTTModel)
          if is_tree_model:
              self.form_field = TreeNodeChoiceField
  class MultiObjectVar(ScriptVariable):
      """
      One or more NetBox objects (the multi-select counterpart of ObjectVar).
      """
      form_field = DynamicModelMultipleChoiceField

      def __init__(self, queryset, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.field_attrs['queryset'] = queryset

          # Nested (MPTT) models get a tree-aware field; this is stored on the
          # instance, shadowing the class-level default.
          is_tree_model = issubclass(queryset.model, MPTTModel)
          if is_tree_model:
              self.form_field = TreeNodeMultipleChoiceField
  class FileVar(ScriptVariable):
      """
      An uploaded file, handled by a Django FileField.
      """
      form_field = forms.FileField
  class IPAddressVar(ScriptVariable):
      """
      A bare IPv4 or IPv6 address (no mask).
      """
      form_field = IPAddressFormField
  class IPAddressWithMaskVar(ScriptVariable):
      """
      An IPv4 or IPv6 address that includes a mask.
      """
      form_field = IPNetworkFormField
  class IPNetworkVar(ScriptVariable):
      """
      An IPv4 or IPv6 prefix, optionally restricted to a range of prefix lengths.
      """
      form_field = IPNetworkFormField

      def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs):
          super().__init__(*args, **kwargs)

          # prefix_validator always applies; the length bounds only when supplied
          validators = [prefix_validator]
          if min_prefix_length is not None:
              validators.append(MinPrefixLengthValidator(min_prefix_length))
          if max_prefix_length is not None:
              validators.append(MaxPrefixLengthValidator(max_prefix_length))
          self.field_attrs['validators'] = validators
  #
  # Scripts
  #


  class BaseScript:
      """
      Base model for custom scripts. User classes should inherit from this model if they want to extend Script
      functionality for use in other subclasses.
      """
      class Meta:
          pass

      def __init__(self):
          # Initiate the log: a Python logger plus an in-memory list of (level, message) tuples
          self.logger = logging.getLogger(f"netbox.scripts.{self.module()}.{self.__class__.__name__}")
          self.log = []

          # Declare the placeholder for the current request
          self.request = None

          # Grab some info about the script
          self.filename = inspect.getfile(self.__class__)
          self.source = inspect.getsource(self.__class__)

      def __str__(self):
          # Prefer the human-friendly Meta.name, falling back to the class name
          return getattr(self.Meta, 'name', self.__class__.__name__)

      @classmethod
      def module(cls):
          """Return the name of the module in which this script class is defined."""
          return cls.__module__

      @classmethod
      def _get_vars(cls):
          """
          Return an OrderedDict mapping attribute name -> ScriptVariable for every
          variable declared directly on this class, in class-body order.

          Only ``cls.__dict__`` is inspected, so variables declared on parent
          classes are not included.
          """
          return OrderedDict(
              (name, attr) for name, attr in cls.__dict__.items()
              if issubclass(attr.__class__, ScriptVariable)
          )

      def run(self, data, commit):
          raise NotImplementedError("The script must define a run() method.")

      def as_form(self, data=None, files=None, initial=None):
          """
          Return a Django form suitable for populating the context data required to run this Script.
          """
          # Create a dynamic ScriptForm subclass from script variables
          fields = {
              name: var.as_field() for name, var in self._get_vars().items()
          }
          FormClass = type('ScriptForm', (ScriptForm,), fields)
          form = FormClass(data, files, initial=initial)

          # Set initial "commit" checkbox state based on the script's Meta parameter
          form.fields['_commit'].initial = getattr(self.Meta, 'commit_default', True)

          return form

      # Logging

      def log_debug(self, message):
          self.logger.log(logging.DEBUG, message)
          self.log.append((LOG_DEFAULT, message))

      def log_success(self, message):
          self.logger.log(logging.INFO, message)  # No syslog equivalent for SUCCESS
          self.log.append((LOG_SUCCESS, message))

      def log_info(self, message):
          self.logger.log(logging.INFO, message)
          self.log.append((LOG_INFO, message))

      def log_warning(self, message):
          self.logger.log(logging.WARNING, message)
          self.log.append((LOG_WARNING, message))

      def log_failure(self, message):
          self.logger.log(logging.ERROR, message)
          self.log.append((LOG_FAILURE, message))

      # Convenience functions

      def load_yaml(self, filename):
          """
          Return data from a YAML file located under settings.SCRIPTS_ROOT.

          The file is parsed with yaml.safe_load(). A bare yaml.load() without an
          explicit Loader can construct arbitrary Python objects, and PyYAML 6.0
          made the Loader argument mandatory (omitting it raises TypeError).
          """
          file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
          with open(file_path, 'r') as datafile:
              data = yaml.safe_load(datafile)

          return data

      def load_json(self, filename):
          """
          Return data from a JSON file located under settings.SCRIPTS_ROOT.
          """
          file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
          with open(file_path, 'r') as datafile:
              data = json.load(datafile)

          return data
  class Script(BaseScript):
      """
      Concrete base for user scripts: every subclass of this class is listed among
      the available scripts.
      """
  #
  # Functions
  #


  def is_script(obj):
      """
      Return True if obj is a subclass of Script other than Script itself.
      Objects that are not classes yield False.
      """
      try:
          if not issubclass(obj, Script):
              return False
          return obj != Script
      except TypeError:
          # issubclass() raises TypeError when obj is not a class
          return False
  def is_variable(obj):
      """
      Report whether obj is an instance of ScriptVariable or any of its subclasses.
      """
      matches = isinstance(obj, ScriptVariable)
      return matches
  def run_script(script, data, request, commit=True):
      """
      A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
      exists outside of the Script class to ensure it cannot be overridden by a script author.

      :param script: Script instance to execute
      :param data: Form data passed to run(); uploaded files from ``request.FILES`` are merged into it in place
      :param request: The current request; attached to the script as ``script.request``
      :param commit: If False, the transaction is aborted so database changes are rolled back
      :return: ``(output, execution_time)``. ``execution_time`` is None when run() did not
          return normally (it raised an exception, including AbortTransaction).
      """
      output = None
      start_time = None
      end_time = None

      script_name = script.__class__.__name__
      logger = logging.getLogger(f"netbox.scripts.{script.module()}.{script_name}")
      logger.info(f"Running script (commit={commit})")

      # Add files to form data
      files = request.FILES
      for field_name, fileobj in files.items():
          data[field_name] = fileobj

      # Add the current request as a property of the script
      script.request = request

      # Determine whether the script accepts a 'commit' argument (this was introduced in v2.7.8)
      kwargs = {
          'data': data
      }
      if 'commit' in inspect.signature(script.run).parameters:
          kwargs['commit'] = commit

      try:
          with transaction.atomic():
              start_time = time.time()
              output = script.run(**kwargs)
              end_time = time.time()
              if not commit:
                  raise AbortTransaction()
      except AbortTransaction:
          # Raised either by us (commit=False) or by the script itself to bail out.
          # Either way the atomic block was exited via an exception, so treat it as
          # an uncommitted run: the finally clause then purges pending changelog
          # entries and reports the revert even when the caller asked to commit.
          commit = False
      except Exception as e:
          stacktrace = traceback.format_exc()
          script.log_failure(
              "An exception occurred: `{}: {}`\n```\n{}\n```".format(type(e).__name__, e, stacktrace)
          )
          logger.error(f"Exception raised during script execution: {e}")
          commit = False
      finally:
          if not commit:
              # Delete all pending changelog entries
              purge_changelog.send(Script)
              script.log_info(
                  "Database changes have been reverted automatically."
              )

      # Calculate execution time
      if end_time is not None:
          execution_time = end_time - start_time
          logger.info(f"Script completed in {execution_time:.4f} seconds")
      else:
          execution_time = None

      return output, execution_time
  def get_scripts(use_names=False):
      """
      Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
      defined name in place of the actual module name.

      Every module found directly in settings.SCRIPTS_ROOT is executed anew on each
      call and registered in sys.modules under its module name.
      """
      import importlib.util
      import sys

      scripts = OrderedDict()

      # Iterate through all modules within the scripts path. These are the user-created files in which scripts are
      # defined.
      for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
          # Finder.find_module() was removed in Python 3.12 (and Loader.load_module() is
          # deprecated), so use the spec-based import API. The module is registered in
          # sys.modules before execution, as load_module() did, so that inspect can map
          # the classes it defines back to their source file.
          spec = importer.find_spec(module_name)
          module = importlib.util.module_from_spec(spec)
          sys.modules[module_name] = module
          try:
              spec.loader.exec_module(module)
          except BaseException:
              # Do not leave a half-initialized module behind
              sys.modules.pop(module_name, None)
              raise

          if use_names and hasattr(module, 'name'):
              module_name = module.name

          module_scripts = OrderedDict()
          for name, cls in inspect.getmembers(module, is_script):
              module_scripts[name] = cls
          if module_scripts:
              scripts[module_name] = module_scripts

      return scripts
  def get_script(module_name, script_name):
      """
      Look up a script class by its module name and class name.

      Returns None when the module is unknown (or has no scripts) or when the
      module has no script with the given name.
      """
      module_scripts = get_scripts().get(module_name)
      if not module_scripts:
          return None
      return module_scripts.get(script_name)
  367. module = scripts.get(module_name)
  368. if module:
  369. return module.get(script_name)