# data_backends.py
  1. import logging
  2. import os
  3. import re
  4. import tempfile
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. from urllib.parse import urlparse
  8. import boto3
  9. from botocore.config import Config as Boto3Config
  10. from django import forms
  11. from django.conf import settings
  12. from django.utils.translation import gettext as _
  13. from dulwich import porcelain
  14. from dulwich.config import StackedConfig
  15. from netbox.registry import registry
  16. from .choices import DataSourceTypeChoices
  17. from .exceptions import SyncError
  18. __all__ = (
  19. 'LocalBackend',
  20. 'GitBackend',
  21. 'S3Backend',
  22. )
  23. logger = logging.getLogger('netbox.data_backends')
  24. def register_backend(name):
  25. """
  26. Decorator for registering a DataBackend class.
  27. """
  28. def _wrapper(cls):
  29. registry['data_backends'][name] = cls
  30. return cls
  31. return _wrapper
  32. class DataBackend:
  33. parameters = {}
  34. def __init__(self, url, **kwargs):
  35. self.url = url
  36. self.params = kwargs
  37. @property
  38. def url_scheme(self):
  39. return urlparse(self.url).scheme.lower()
  40. @contextmanager
  41. def fetch(self):
  42. raise NotImplemented()
  43. @register_backend(DataSourceTypeChoices.LOCAL)
  44. class LocalBackend(DataBackend):
  45. @contextmanager
  46. def fetch(self):
  47. logger.debug(f"Data source type is local; skipping fetch")
  48. local_path = urlparse(self.url).path # Strip file:// scheme
  49. yield local_path
  50. @register_backend(DataSourceTypeChoices.GIT)
  51. class GitBackend(DataBackend):
  52. parameters = {
  53. 'username': forms.CharField(
  54. required=False,
  55. label=_('Username'),
  56. widget=forms.TextInput(attrs={'class': 'form-control'})
  57. ),
  58. 'password': forms.CharField(
  59. required=False,
  60. label=_('Password'),
  61. widget=forms.TextInput(attrs={'class': 'form-control'})
  62. ),
  63. 'branch': forms.CharField(
  64. required=False,
  65. label=_('Branch'),
  66. widget=forms.TextInput(attrs={'class': 'form-control'})
  67. )
  68. }
  69. @contextmanager
  70. def fetch(self):
  71. local_path = tempfile.TemporaryDirectory()
  72. username = self.params.get('username')
  73. password = self.params.get('password')
  74. branch = self.params.get('branch')
  75. config = StackedConfig.default()
  76. if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
  77. if proxy := settings.HTTP_PROXIES.get(self.url_scheme):
  78. config.set("http", "proxy", proxy)
  79. logger.debug(f"Cloning git repo: {self.url}")
  80. try:
  81. porcelain.clone(
  82. self.url, local_path.name, depth=1, branch=branch, username=username, password=password,
  83. config=config, quiet=True, errstream=porcelain.NoneStream()
  84. )
  85. except BaseException as e:
  86. raise SyncError(f"Fetching remote data failed ({type(e).__name__}): {e}")
  87. yield local_path.name
  88. local_path.cleanup()
  89. @register_backend(DataSourceTypeChoices.AMAZON_S3)
  90. class S3Backend(DataBackend):
  91. parameters = {
  92. 'aws_access_key_id': forms.CharField(
  93. label=_('AWS access key ID'),
  94. widget=forms.TextInput(attrs={'class': 'form-control'})
  95. ),
  96. 'aws_secret_access_key': forms.CharField(
  97. label=_('AWS secret access key'),
  98. widget=forms.TextInput(attrs={'class': 'form-control'})
  99. ),
  100. }
  101. REGION_REGEX = r's3\.([a-z0-9-]+)\.amazonaws\.com'
  102. @contextmanager
  103. def fetch(self):
  104. local_path = tempfile.TemporaryDirectory()
  105. # Build the S3 configuration
  106. s3_config = Boto3Config(
  107. proxies=settings.HTTP_PROXIES,
  108. )
  109. # Initialize the S3 resource and bucket
  110. aws_access_key_id = self.params.get('aws_access_key_id')
  111. aws_secret_access_key = self.params.get('aws_secret_access_key')
  112. s3 = boto3.resource(
  113. 's3',
  114. region_name=self._region_name,
  115. aws_access_key_id=aws_access_key_id,
  116. aws_secret_access_key=aws_secret_access_key,
  117. config=s3_config
  118. )
  119. bucket = s3.Bucket(self._bucket_name)
  120. # Download all files within the specified path
  121. for obj in bucket.objects.filter(Prefix=self._remote_path):
  122. local_filename = os.path.join(local_path.name, obj.key)
  123. # Build local path
  124. Path(os.path.dirname(local_filename)).mkdir(parents=True, exist_ok=True)
  125. bucket.download_file(obj.key, local_filename)
  126. yield local_path.name
  127. local_path.cleanup()
  128. @property
  129. def _region_name(self):
  130. domain = urlparse(self.url).netloc
  131. if m := re.match(self.REGION_REGEX, domain):
  132. return m.group(1)
  133. return None
  134. @property
  135. def _bucket_name(self):
  136. url_path = urlparse(self.url).path.lstrip('/')
  137. return url_path.split('/')[0]
  138. @property
  139. def _remote_path(self):
  140. url_path = urlparse(self.url).path.lstrip('/')
  141. if '/' in url_path:
  142. return url_path.split('/', 1)[1]
  143. return ''