scripts.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. from collections import OrderedDict
  2. import inspect
  3. import json
  4. import os
  5. import pkgutil
  6. import time
  7. import traceback
  8. import yaml
  9. from django import forms
  10. from django.conf import settings
  11. from django.core.validators import RegexValidator
  12. from django.db import transaction
  13. from mptt.forms import TreeNodeChoiceField
  14. from mptt.models import MPTTModel
  15. from ipam.formfields import IPFormField
  16. from utilities.exceptions import AbortTransaction
  17. from utilities.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator
  18. from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING
  19. from .forms import ScriptForm
  20. from .signals import purge_changelog
  21. __all__ = [
  22. 'BaseScript',
  23. 'BooleanVar',
  24. 'FileVar',
  25. 'IntegerVar',
  26. 'IPNetworkVar',
  27. 'ObjectVar',
  28. 'Script',
  29. 'StringVar',
  30. 'TextVar',
  31. ]
#
# Script variables
#
  35. class ScriptVariable:
  36. """
  37. Base model for script variables
  38. """
  39. form_field = forms.CharField
  40. def __init__(self, label='', description='', default=None, required=True):
  41. # Default field attributes
  42. self.field_attrs = {
  43. 'help_text': description,
  44. 'required': required
  45. }
  46. if label:
  47. self.field_attrs['label'] = label
  48. if default:
  49. self.field_attrs['initial'] = default
  50. def as_field(self):
  51. """
  52. Render the variable as a Django form field.
  53. """
  54. form_field = self.form_field(**self.field_attrs)
  55. if not isinstance(form_field.widget, forms.CheckboxInput):
  56. form_field.widget.attrs['class'] = 'form-control'
  57. return form_field
  58. class StringVar(ScriptVariable):
  59. """
  60. Character string representation. Can enforce minimum/maximum length and/or regex validation.
  61. """
  62. def __init__(self, min_length=None, max_length=None, regex=None, *args, **kwargs):
  63. super().__init__(*args, **kwargs)
  64. # Optional minimum/maximum lengths
  65. if min_length:
  66. self.field_attrs['min_length'] = min_length
  67. if max_length:
  68. self.field_attrs['max_length'] = max_length
  69. # Optional regular expression validation
  70. if regex:
  71. self.field_attrs['validators'] = [
  72. RegexValidator(
  73. regex=regex,
  74. message='Invalid value. Must match regex: {}'.format(regex),
  75. code='invalid'
  76. )
  77. ]
  78. class TextVar(ScriptVariable):
  79. """
  80. Free-form text data. Renders as a <textarea>.
  81. """
  82. form_field = forms.CharField
  83. def __init__(self, *args, **kwargs):
  84. super().__init__(*args, **kwargs)
  85. self.field_attrs['widget'] = forms.Textarea
  86. class IntegerVar(ScriptVariable):
  87. """
  88. Integer representation. Can enforce minimum/maximum values.
  89. """
  90. form_field = forms.IntegerField
  91. def __init__(self, min_value=None, max_value=None, *args, **kwargs):
  92. super().__init__(*args, **kwargs)
  93. # Optional minimum/maximum values
  94. if min_value:
  95. self.field_attrs['min_value'] = min_value
  96. if max_value:
  97. self.field_attrs['max_value'] = max_value
  98. class BooleanVar(ScriptVariable):
  99. """
  100. Boolean representation (true/false). Renders as a checkbox.
  101. """
  102. form_field = forms.BooleanField
  103. def __init__(self, *args, **kwargs):
  104. super().__init__(*args, **kwargs)
  105. # Boolean fields cannot be required
  106. self.field_attrs['required'] = False
  107. class ObjectVar(ScriptVariable):
  108. """
  109. NetBox object representation. The provided QuerySet will determine the choices available.
  110. """
  111. form_field = forms.ModelChoiceField
  112. def __init__(self, queryset, *args, **kwargs):
  113. super().__init__(*args, **kwargs)
  114. # Queryset for field choices
  115. self.field_attrs['queryset'] = queryset
  116. # Update form field for MPTT (nested) objects
  117. if issubclass(queryset.model, MPTTModel):
  118. self.form_field = TreeNodeChoiceField
  119. class FileVar(ScriptVariable):
  120. """
  121. An uploaded file.
  122. """
  123. form_field = forms.FileField
  124. class IPNetworkVar(ScriptVariable):
  125. """
  126. An IPv4 or IPv6 prefix.
  127. """
  128. form_field = IPFormField
  129. def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs):
  130. super().__init__(*args, **kwargs)
  131. self.field_attrs['validators'] = list()
  132. # Optional minimum/maximum prefix lengths
  133. if min_prefix_length is not None:
  134. self.field_attrs['validators'].append(
  135. MinPrefixLengthValidator(min_prefix_length)
  136. )
  137. if max_prefix_length is not None:
  138. self.field_attrs['validators'].append(
  139. MaxPrefixLengthValidator(max_prefix_length)
  140. )
#
# Scripts
#
  144. class BaseScript:
  145. """
  146. Base model for custom scripts. User classes should inherit from this model if they want to extend Script
  147. functionality for use in other subclasses.
  148. """
  149. class Meta:
  150. pass
  151. def __init__(self):
  152. # Initiate the log
  153. self.log = []
  154. # Grab some info about the script
  155. self.filename = inspect.getfile(self.__class__)
  156. self.source = inspect.getsource(self.__class__)
  157. def __str__(self):
  158. return getattr(self.Meta, 'name', self.__class__.__name__)
  159. def _get_vars(self):
  160. vars = OrderedDict()
  161. # Infer order from Meta.field_order (Python 3.5 and lower)
  162. field_order = getattr(self.Meta, 'field_order', [])
  163. for name in field_order:
  164. vars[name] = getattr(self, name)
  165. # Default to order of declaration on class
  166. for name, attr in self.__class__.__dict__.items():
  167. if name not in vars and issubclass(attr.__class__, ScriptVariable):
  168. vars[name] = attr
  169. return vars
  170. def run(self, data):
  171. raise NotImplementedError("The script must define a run() method.")
  172. def as_form(self, data=None, files=None):
  173. """
  174. Return a Django form suitable for populating the context data required to run this Script.
  175. """
  176. vars = self._get_vars()
  177. form = ScriptForm(vars, data, files)
  178. return form
  179. # Logging
  180. def log_debug(self, message):
  181. self.log.append((LOG_DEFAULT, message))
  182. def log_success(self, message):
  183. self.log.append((LOG_SUCCESS, message))
  184. def log_info(self, message):
  185. self.log.append((LOG_INFO, message))
  186. def log_warning(self, message):
  187. self.log.append((LOG_WARNING, message))
  188. def log_failure(self, message):
  189. self.log.append((LOG_FAILURE, message))
  190. # Convenience functions
  191. def load_yaml(self, filename):
  192. """
  193. Return data from a YAML file
  194. """
  195. file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
  196. with open(file_path, 'r') as datafile:
  197. data = yaml.load(datafile)
  198. return data
  199. def load_json(self, filename):
  200. """
  201. Return data from a JSON file
  202. """
  203. file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
  204. with open(file_path, 'r') as datafile:
  205. data = json.load(datafile)
  206. return data
  207. class Script(BaseScript):
  208. """
  209. Classes which inherit this model will appear in the list of available scripts.
  210. """
  211. pass
#
# Functions
#
  215. def is_script(obj):
  216. """
  217. Returns True if the object is a Script.
  218. """
  219. try:
  220. return issubclass(obj, Script) and obj != Script
  221. except TypeError:
  222. return False
  223. def is_variable(obj):
  224. """
  225. Returns True if the object is a ScriptVariable.
  226. """
  227. return isinstance(obj, ScriptVariable)
  228. def run_script(script, data, files, commit=True):
  229. """
  230. A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
  231. exists outside of the Script class to ensure it cannot be overridden by a script author.
  232. """
  233. output = None
  234. start_time = None
  235. end_time = None
  236. # Add files to form data
  237. for field_name, fileobj in files.items():
  238. data[field_name] = fileobj
  239. try:
  240. with transaction.atomic():
  241. start_time = time.time()
  242. output = script.run(data)
  243. end_time = time.time()
  244. if not commit:
  245. raise AbortTransaction()
  246. except AbortTransaction:
  247. pass
  248. except Exception as e:
  249. stacktrace = traceback.format_exc()
  250. script.log_failure(
  251. "An exception occurred: `{}: {}`\n```\n{}\n```".format(type(e).__name__, e, stacktrace)
  252. )
  253. commit = False
  254. finally:
  255. if not commit:
  256. # Delete all pending changelog entries
  257. purge_changelog.send(Script)
  258. script.log_info(
  259. "Database changes have been reverted automatically."
  260. )
  261. # Calculate execution time
  262. if end_time is not None:
  263. execution_time = end_time - start_time
  264. else:
  265. execution_time = None
  266. return output, execution_time
  267. def get_scripts():
  268. scripts = OrderedDict()
  269. # Iterate through all modules within the reports path. These are the user-created files in which reports are
  270. # defined.
  271. for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
  272. module = importer.find_module(module_name).load_module(module_name)
  273. if hasattr(module, 'name'):
  274. module_name = module.name
  275. module_scripts = OrderedDict()
  276. for name, cls in inspect.getmembers(module, is_script):
  277. module_scripts[name] = cls
  278. scripts[module_name] = module_scripts
  279. return scripts