scripts.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. import inspect
  2. import json
  3. import os
  4. import pkgutil
  5. import time
  6. import traceback
  7. from collections import OrderedDict
  8. import yaml
  9. from django import forms
  10. from django.conf import settings
  11. from django.core.validators import RegexValidator
  12. from django.db import transaction
  13. from mptt.forms import TreeNodeChoiceField, TreeNodeMultipleChoiceField
  14. from mptt.models import MPTTModel
  15. from ipam.formfields import IPAddressFormField, IPNetworkFormField
  16. from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator
  17. from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING
  18. from utilities.exceptions import AbortTransaction
  19. from utilities.forms import DynamicModelChoiceField, DynamicModelMultipleChoiceField
  20. from .forms import ScriptForm
  21. from .signals import purge_changelog
  22. __all__ = [
  23. 'BaseScript',
  24. 'BooleanVar',
  25. 'ChoiceVar',
  26. 'FileVar',
  27. 'IntegerVar',
  28. 'IPAddressVar',
  29. 'IPAddressWithMaskVar',
  30. 'IPNetworkVar',
  31. 'MultiObjectVar',
  32. 'ObjectVar',
  33. 'Script',
  34. 'StringVar',
  35. 'TextVar',
  36. ]
  37. #
  38. # Script variables
  39. #
  40. class ScriptVariable:
  41. """
  42. Base model for script variables
  43. """
  44. form_field = forms.CharField
  45. def __init__(self, label='', description='', default=None, required=True, widget=None):
  46. # Initialize field attributes
  47. if not hasattr(self, 'field_attrs'):
  48. self.field_attrs = {}
  49. if label:
  50. self.field_attrs['label'] = label
  51. if description:
  52. self.field_attrs['help_text'] = description
  53. if default:
  54. self.field_attrs['initial'] = default
  55. if widget:
  56. self.field_attrs['widget'] = widget
  57. self.field_attrs['required'] = required
  58. def as_field(self):
  59. """
  60. Render the variable as a Django form field.
  61. """
  62. form_field = self.form_field(**self.field_attrs)
  63. if not isinstance(form_field.widget, forms.CheckboxInput):
  64. if form_field.widget.attrs and 'class' in form_field.widget.attrs.keys():
  65. form_field.widget.attrs['class'] += ' form-control'
  66. else:
  67. form_field.widget.attrs['class'] = 'form-control'
  68. return form_field
  69. class StringVar(ScriptVariable):
  70. """
  71. Character string representation. Can enforce minimum/maximum length and/or regex validation.
  72. """
  73. def __init__(self, min_length=None, max_length=None, regex=None, *args, **kwargs):
  74. super().__init__(*args, **kwargs)
  75. # Optional minimum/maximum lengths
  76. if min_length:
  77. self.field_attrs['min_length'] = min_length
  78. if max_length:
  79. self.field_attrs['max_length'] = max_length
  80. # Optional regular expression validation
  81. if regex:
  82. self.field_attrs['validators'] = [
  83. RegexValidator(
  84. regex=regex,
  85. message='Invalid value. Must match regex: {}'.format(regex),
  86. code='invalid'
  87. )
  88. ]
  89. class TextVar(ScriptVariable):
  90. """
  91. Free-form text data. Renders as a <textarea>.
  92. """
  93. form_field = forms.CharField
  94. def __init__(self, *args, **kwargs):
  95. super().__init__(*args, **kwargs)
  96. self.field_attrs['widget'] = forms.Textarea
  97. class IntegerVar(ScriptVariable):
  98. """
  99. Integer representation. Can enforce minimum/maximum values.
  100. """
  101. form_field = forms.IntegerField
  102. def __init__(self, min_value=None, max_value=None, *args, **kwargs):
  103. super().__init__(*args, **kwargs)
  104. # Optional minimum/maximum values
  105. if min_value:
  106. self.field_attrs['min_value'] = min_value
  107. if max_value:
  108. self.field_attrs['max_value'] = max_value
  109. class BooleanVar(ScriptVariable):
  110. """
  111. Boolean representation (true/false). Renders as a checkbox.
  112. """
  113. form_field = forms.BooleanField
  114. def __init__(self, *args, **kwargs):
  115. super().__init__(*args, **kwargs)
  116. # Boolean fields cannot be required
  117. self.field_attrs['required'] = False
  118. class ChoiceVar(ScriptVariable):
  119. """
  120. Select one of several predefined static choices, passed as a list of two-tuples. Example:
  121. color = ChoiceVar(
  122. choices=(
  123. ('#ff0000', 'Red'),
  124. ('#00ff00', 'Green'),
  125. ('#0000ff', 'Blue')
  126. )
  127. )
  128. """
  129. form_field = forms.ChoiceField
  130. def __init__(self, choices, *args, **kwargs):
  131. super().__init__(*args, **kwargs)
  132. # Set field choices
  133. self.field_attrs['choices'] = choices
  134. class ObjectVar(ScriptVariable):
  135. """
  136. NetBox object representation. The provided QuerySet will determine the choices available.
  137. """
  138. form_field = forms.ModelChoiceField
  139. def __init__(self, queryset, *args, **kwargs):
  140. super().__init__(*args, **kwargs)
  141. # Queryset for field choices
  142. self.field_attrs['queryset'] = queryset
  143. # Update form field for MPTT (nested) objects
  144. if issubclass(queryset.model, MPTTModel):
  145. self.form_field = TreeNodeChoiceField
  146. class MultiObjectVar(ScriptVariable):
  147. """
  148. Like ObjectVar, but can represent one or more objects.
  149. """
  150. form_field = forms.ModelMultipleChoiceField
  151. def __init__(self, queryset, *args, **kwargs):
  152. super().__init__(*args, **kwargs)
  153. # Queryset for field choices
  154. self.field_attrs['queryset'] = queryset
  155. # Update form field for MPTT (nested) objects
  156. if issubclass(queryset.model, MPTTModel):
  157. self.form_field = TreeNodeMultipleChoiceField
  158. class DynamicObjectVar(ObjectVar):
  159. """
  160. A dynamic netbox object variable. APISelect will determine the available choices
  161. """
  162. form_field = DynamicModelChoiceField
  163. class DynamicMultiObjectVar(MultiObjectVar):
  164. """
  165. A multiple choice version of DynamicObjectVar
  166. """
  167. form_field = DynamicModelMultipleChoiceField
  168. class FileVar(ScriptVariable):
  169. """
  170. An uploaded file.
  171. """
  172. form_field = forms.FileField
  173. class IPAddressVar(ScriptVariable):
  174. """
  175. An IPv4 or IPv6 address without a mask.
  176. """
  177. form_field = IPAddressFormField
  178. class IPAddressWithMaskVar(ScriptVariable):
  179. """
  180. An IPv4 or IPv6 address with a mask.
  181. """
  182. form_field = IPNetworkFormField
  183. class IPNetworkVar(ScriptVariable):
  184. """
  185. An IPv4 or IPv6 prefix.
  186. """
  187. form_field = IPNetworkFormField
  188. def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs):
  189. super().__init__(*args, **kwargs)
  190. # Set prefix validator and optional minimum/maximum prefix lengths
  191. self.field_attrs['validators'] = [prefix_validator]
  192. if min_prefix_length is not None:
  193. self.field_attrs['validators'].append(
  194. MinPrefixLengthValidator(min_prefix_length)
  195. )
  196. if max_prefix_length is not None:
  197. self.field_attrs['validators'].append(
  198. MaxPrefixLengthValidator(max_prefix_length)
  199. )
  200. #
  201. # Scripts
  202. #
  203. class BaseScript:
  204. """
  205. Base model for custom scripts. User classes should inherit from this model if they want to extend Script
  206. functionality for use in other subclasses.
  207. """
  208. class Meta:
  209. pass
  210. def __init__(self):
  211. # Initiate the log
  212. self.log = []
  213. # Declare the placeholder for the current request
  214. self.request = None
  215. # Grab some info about the script
  216. self.filename = inspect.getfile(self.__class__)
  217. self.source = inspect.getsource(self.__class__)
  218. def __str__(self):
  219. return getattr(self.Meta, 'name', self.__class__.__name__)
  220. @classmethod
  221. def module(cls):
  222. return cls.__module__
  223. @classmethod
  224. def _get_vars(cls):
  225. vars = OrderedDict()
  226. # Infer order from Meta.field_order (Python 3.5 and lower)
  227. field_order = getattr(cls.Meta, 'field_order', [])
  228. for name in field_order:
  229. vars[name] = getattr(cls, name)
  230. # Default to order of declaration on class
  231. for name, attr in cls.__dict__.items():
  232. if name not in vars and issubclass(attr.__class__, ScriptVariable):
  233. vars[name] = attr
  234. return vars
  235. def run(self, data):
  236. raise NotImplementedError("The script must define a run() method.")
  237. def as_form(self, data=None, files=None, initial=None):
  238. """
  239. Return a Django form suitable for populating the context data required to run this Script.
  240. """
  241. vars = self._get_vars()
  242. form = ScriptForm(vars, data, files, initial=initial, commit_default=getattr(self.Meta, 'commit_default', True))
  243. return form
  244. # Logging
  245. def log_debug(self, message):
  246. self.log.append((LOG_DEFAULT, message))
  247. def log_success(self, message):
  248. self.log.append((LOG_SUCCESS, message))
  249. def log_info(self, message):
  250. self.log.append((LOG_INFO, message))
  251. def log_warning(self, message):
  252. self.log.append((LOG_WARNING, message))
  253. def log_failure(self, message):
  254. self.log.append((LOG_FAILURE, message))
  255. # Convenience functions
  256. def load_yaml(self, filename):
  257. """
  258. Return data from a YAML file
  259. """
  260. file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
  261. with open(file_path, 'r') as datafile:
  262. data = yaml.load(datafile)
  263. return data
  264. def load_json(self, filename):
  265. """
  266. Return data from a JSON file
  267. """
  268. file_path = os.path.join(settings.SCRIPTS_ROOT, filename)
  269. with open(file_path, 'r') as datafile:
  270. data = json.load(datafile)
  271. return data
  272. class Script(BaseScript):
  273. """
  274. Classes which inherit this model will appear in the list of available scripts.
  275. """
  276. pass
  277. #
  278. # Functions
  279. #
  280. def is_script(obj):
  281. """
  282. Returns True if the object is a Script.
  283. """
  284. try:
  285. return issubclass(obj, Script) and obj != Script
  286. except TypeError:
  287. return False
  288. def is_variable(obj):
  289. """
  290. Returns True if the object is a ScriptVariable.
  291. """
  292. return isinstance(obj, ScriptVariable)
  293. def run_script(script, data, request, commit=True):
  294. """
  295. A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
  296. exists outside of the Script class to ensure it cannot be overridden by a script author.
  297. """
  298. output = None
  299. start_time = None
  300. end_time = None
  301. # Add files to form data
  302. files = request.FILES
  303. for field_name, fileobj in files.items():
  304. data[field_name] = fileobj
  305. # Add the current request as a property of the script
  306. script.request = request
  307. try:
  308. with transaction.atomic():
  309. start_time = time.time()
  310. output = script.run(data)
  311. end_time = time.time()
  312. if not commit:
  313. raise AbortTransaction()
  314. except AbortTransaction:
  315. pass
  316. except Exception as e:
  317. stacktrace = traceback.format_exc()
  318. script.log_failure(
  319. "An exception occurred: `{}: {}`\n```\n{}\n```".format(type(e).__name__, e, stacktrace)
  320. )
  321. commit = False
  322. finally:
  323. if not commit:
  324. # Delete all pending changelog entries
  325. purge_changelog.send(Script)
  326. script.log_info(
  327. "Database changes have been reverted automatically."
  328. )
  329. # Calculate execution time
  330. if end_time is not None:
  331. execution_time = end_time - start_time
  332. else:
  333. execution_time = None
  334. return output, execution_time
  335. def get_scripts(use_names=False):
  336. """
  337. Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
  338. defined name in place of the actual module name.
  339. """
  340. scripts = OrderedDict()
  341. # Iterate through all modules within the reports path. These are the user-created files in which reports are
  342. # defined.
  343. for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
  344. module = importer.find_module(module_name).load_module(module_name)
  345. if use_names and hasattr(module, 'name'):
  346. module_name = module.name
  347. module_scripts = OrderedDict()
  348. for name, cls in inspect.getmembers(module, is_script):
  349. module_scripts[name] = cls
  350. scripts[module_name] = module_scripts
  351. return scripts
  352. def get_script(module_name, script_name):
  353. """
  354. Retrieve a script class by module and name. Returns None if the script does not exist.
  355. """
  356. scripts = get_scripts()
  357. module = scripts.get(module_name)
  358. if module:
  359. return module.get(script_name)