data_backends.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import logging
  2. import os
  3. import re
  4. import tempfile
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. from urllib.parse import urlparse
  8. from django import forms
  9. from django.core.exceptions import ImproperlyConfigured
  10. from django.utils.translation import gettext as _
  11. from netbox.data_backends import DataBackend
  12. from netbox.utils import register_data_backend
  13. from utilities.constants import HTTP_PROXY_SUPPORTED_SCHEMAS, HTTP_PROXY_SUPPORTED_SOCK_SCHEMAS
  14. from utilities.proxy import resolve_proxies
  15. from utilities.socks import ProxyPoolManager
  16. from .exceptions import SyncError
  17. __all__ = (
  18. 'GitBackend',
  19. 'LocalBackend',
  20. 'S3Backend',
  21. )
  22. logger = logging.getLogger('netbox.data_backends')
  23. @register_data_backend()
  24. class LocalBackend(DataBackend):
  25. name = 'local'
  26. label = _('Local')
  27. is_local = True
  28. @contextmanager
  29. def fetch(self):
  30. logger.debug("Data source type is local; skipping fetch")
  31. local_path = urlparse(self.url).path # Strip file:// scheme
  32. yield local_path
  33. @register_data_backend()
  34. class GitBackend(DataBackend):
  35. name = 'git'
  36. label = 'Git'
  37. parameters = {
  38. 'username': forms.CharField(
  39. required=False,
  40. label=_('Username'),
  41. widget=forms.TextInput(attrs={'class': 'form-control'}),
  42. help_text=_("Only used for cloning with HTTP(S)"),
  43. ),
  44. 'password': forms.CharField(
  45. required=False,
  46. label=_('Password'),
  47. widget=forms.TextInput(attrs={'class': 'form-control'}),
  48. help_text=_("Only used for cloning with HTTP(S)"),
  49. ),
  50. 'branch': forms.CharField(
  51. required=False,
  52. label=_('Branch'),
  53. widget=forms.TextInput(attrs={'class': 'form-control'})
  54. )
  55. }
  56. sensitive_parameters = ['password']
  57. def init_config(self):
  58. from dulwich.config import ConfigDict
  59. # Initialize backend config
  60. config = ConfigDict()
  61. self.socks_proxy = None
  62. # Apply HTTP proxy (if configured)
  63. proxies = resolve_proxies(url=self.url, context={'client': self}) or {}
  64. if proxy := proxies.get(self.url_scheme):
  65. if urlparse(proxy).scheme not in HTTP_PROXY_SUPPORTED_SCHEMAS:
  66. raise ImproperlyConfigured(f"Unsupported Git DataSource proxy scheme: {urlparse(proxy).scheme}")
  67. if self.url_scheme in ('http', 'https'):
  68. config.set("http", "proxy", proxy)
  69. if urlparse(proxy).scheme in HTTP_PROXY_SUPPORTED_SOCK_SCHEMAS:
  70. self.socks_proxy = proxy
  71. return config
  72. @contextmanager
  73. def fetch(self):
  74. from dulwich import porcelain
  75. local_path = tempfile.TemporaryDirectory()
  76. clone_args = {
  77. "branch": self.params.get('branch'),
  78. "config": self.config,
  79. "errstream": porcelain.NoneStream(),
  80. }
  81. # check if using socks for proxy - if so need to use custom pool_manager
  82. if self.socks_proxy:
  83. clone_args['pool_manager'] = ProxyPoolManager(self.socks_proxy)
  84. if self.url_scheme in ('http', 'https'):
  85. if self.params.get('username'):
  86. clone_args.update(
  87. {
  88. "username": self.params.get('username'),
  89. "password": self.params.get('password'),
  90. }
  91. )
  92. if self.url_scheme:
  93. clone_args["quiet"] = True
  94. clone_args["depth"] = 1
  95. logger.debug(f"Cloning git repo: {self.url}")
  96. try:
  97. porcelain.clone(self.url, local_path.name, **clone_args)
  98. except BaseException as e:
  99. raise SyncError(_("Fetching remote data failed ({name}): {error}").format(name=type(e).__name__, error=e))
  100. yield local_path.name
  101. local_path.cleanup()
  102. @register_data_backend()
  103. class S3Backend(DataBackend):
  104. name = 'amazon-s3'
  105. label = 'Amazon S3'
  106. parameters = {
  107. 'aws_access_key_id': forms.CharField(
  108. label=_('AWS access key ID'),
  109. widget=forms.TextInput(attrs={'class': 'form-control'})
  110. ),
  111. 'aws_secret_access_key': forms.CharField(
  112. label=_('AWS secret access key'),
  113. widget=forms.TextInput(attrs={'class': 'form-control'})
  114. ),
  115. }
  116. sensitive_parameters = ['aws_secret_access_key']
  117. REGION_REGEX = r's3\.([a-z0-9-]+)\.amazonaws\.com'
  118. def init_config(self):
  119. from botocore.config import Config as Boto3Config
  120. # Initialize backend config
  121. return Boto3Config(
  122. proxies=resolve_proxies(url=self.url, context={'client': self}),
  123. )
  124. @contextmanager
  125. def fetch(self):
  126. import boto3
  127. local_path = tempfile.TemporaryDirectory()
  128. # Initialize the S3 resource and bucket
  129. aws_access_key_id = self.params.get('aws_access_key_id')
  130. aws_secret_access_key = self.params.get('aws_secret_access_key')
  131. s3 = boto3.resource(
  132. 's3',
  133. region_name=self._region_name,
  134. aws_access_key_id=aws_access_key_id,
  135. aws_secret_access_key=aws_secret_access_key,
  136. config=self.config,
  137. endpoint_url=self._endpoint_url
  138. )
  139. bucket = s3.Bucket(self._bucket_name)
  140. # Download all files within the specified path
  141. for obj in bucket.objects.filter(Prefix=self._remote_path):
  142. local_filename = os.path.join(local_path.name, obj.key)
  143. # Build local path
  144. Path(os.path.dirname(local_filename)).mkdir(parents=True, exist_ok=True)
  145. bucket.download_file(obj.key, local_filename)
  146. yield local_path.name
  147. local_path.cleanup()
  148. @property
  149. def _region_name(self):
  150. domain = urlparse(self.url).netloc
  151. if m := re.match(self.REGION_REGEX, domain):
  152. return m.group(1)
  153. return None
  154. @property
  155. def _bucket_name(self):
  156. url_path = urlparse(self.url).path.lstrip('/')
  157. return url_path.split('/')[0]
  158. @property
  159. def _endpoint_url(self):
  160. url_path = urlparse(self.url)
  161. return url_path._replace(params="", fragment="", query="", path="").geturl()
  162. @property
  163. def _remote_path(self):
  164. url_path = urlparse(self.url).path.lstrip('/')
  165. if '/' in url_path:
  166. return url_path.split('/', 1)[1]
  167. return ''