data_backends.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import logging
  2. import os
  3. import re
  4. import tempfile
  5. from contextlib import contextmanager
  6. from pathlib import Path
  7. from urllib.parse import urlparse
  8. from django import forms
  9. from django.conf import settings
  10. from django.utils.translation import gettext as _
  11. from netbox.data_backends import DataBackend
  12. from netbox.utils import register_data_backend
  13. from .exceptions import SyncError
# Public API of this module: the three built-in data source backends
__all__ = (
    'GitBackend',
    'LocalBackend',
    'S3Backend',
)

logger = logging.getLogger('netbox.data_backends')
  20. @register_data_backend()
  21. class LocalBackend(DataBackend):
  22. name = 'local'
  23. label = _('Local')
  24. is_local = True
  25. @contextmanager
  26. def fetch(self):
  27. logger.debug(f"Data source type is local; skipping fetch")
  28. local_path = urlparse(self.url).path # Strip file:// scheme
  29. yield local_path
  30. @register_data_backend()
  31. class GitBackend(DataBackend):
  32. name = 'git'
  33. label = 'Git'
  34. parameters = {
  35. 'username': forms.CharField(
  36. required=False,
  37. label=_('Username'),
  38. widget=forms.TextInput(attrs={'class': 'form-control'}),
  39. help_text=_("Only used for cloning with HTTP(S)"),
  40. ),
  41. 'password': forms.CharField(
  42. required=False,
  43. label=_('Password'),
  44. widget=forms.TextInput(attrs={'class': 'form-control'}),
  45. help_text=_("Only used for cloning with HTTP(S)"),
  46. ),
  47. 'branch': forms.CharField(
  48. required=False,
  49. label=_('Branch'),
  50. widget=forms.TextInput(attrs={'class': 'form-control'})
  51. )
  52. }
  53. sensitive_parameters = ['password']
  54. def init_config(self):
  55. from dulwich.config import ConfigDict
  56. # Initialize backend config
  57. config = ConfigDict()
  58. # Apply HTTP proxy (if configured)
  59. if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
  60. if proxy := settings.HTTP_PROXIES.get(self.url_scheme):
  61. config.set("http", "proxy", proxy)
  62. return config
  63. @contextmanager
  64. def fetch(self):
  65. from dulwich import porcelain
  66. local_path = tempfile.TemporaryDirectory()
  67. clone_args = {
  68. "branch": self.params.get('branch'),
  69. "config": self.config,
  70. "depth": 1,
  71. "errstream": porcelain.NoneStream(),
  72. "quiet": True,
  73. }
  74. if self.url_scheme in ('http', 'https'):
  75. if self.params.get('username'):
  76. clone_args.update(
  77. {
  78. "username": self.params.get('username'),
  79. "password": self.params.get('password'),
  80. }
  81. )
  82. logger.debug(f"Cloning git repo: {self.url}")
  83. try:
  84. porcelain.clone(self.url, local_path.name, **clone_args)
  85. except BaseException as e:
  86. raise SyncError(f"Fetching remote data failed ({type(e).__name__}): {e}")
  87. yield local_path.name
  88. local_path.cleanup()
  89. @register_data_backend()
  90. class S3Backend(DataBackend):
  91. name = 'amazon-s3'
  92. label = 'Amazon S3'
  93. parameters = {
  94. 'aws_access_key_id': forms.CharField(
  95. label=_('AWS access key ID'),
  96. widget=forms.TextInput(attrs={'class': 'form-control'})
  97. ),
  98. 'aws_secret_access_key': forms.CharField(
  99. label=_('AWS secret access key'),
  100. widget=forms.TextInput(attrs={'class': 'form-control'})
  101. ),
  102. }
  103. sensitive_parameters = ['aws_secret_access_key']
  104. REGION_REGEX = r's3\.([a-z0-9-]+)\.amazonaws\.com'
  105. def init_config(self):
  106. from botocore.config import Config as Boto3Config
  107. # Initialize backend config
  108. return Boto3Config(
  109. proxies=settings.HTTP_PROXIES,
  110. )
  111. @contextmanager
  112. def fetch(self):
  113. import boto3
  114. local_path = tempfile.TemporaryDirectory()
  115. # Initialize the S3 resource and bucket
  116. aws_access_key_id = self.params.get('aws_access_key_id')
  117. aws_secret_access_key = self.params.get('aws_secret_access_key')
  118. s3 = boto3.resource(
  119. 's3',
  120. region_name=self._region_name,
  121. aws_access_key_id=aws_access_key_id,
  122. aws_secret_access_key=aws_secret_access_key,
  123. config=self.config
  124. )
  125. bucket = s3.Bucket(self._bucket_name)
  126. # Download all files within the specified path
  127. for obj in bucket.objects.filter(Prefix=self._remote_path):
  128. local_filename = os.path.join(local_path.name, obj.key)
  129. # Build local path
  130. Path(os.path.dirname(local_filename)).mkdir(parents=True, exist_ok=True)
  131. bucket.download_file(obj.key, local_filename)
  132. yield local_path.name
  133. local_path.cleanup()
  134. @property
  135. def _region_name(self):
  136. domain = urlparse(self.url).netloc
  137. if m := re.match(self.REGION_REGEX, domain):
  138. return m.group(1)
  139. return None
  140. @property
  141. def _bucket_name(self):
  142. url_path = urlparse(self.url).path.lstrip('/')
  143. return url_path.split('/')[0]
  144. @property
  145. def _remote_path(self):
  146. url_path = urlparse(self.url).path.lstrip('/')
  147. if '/' in url_path:
  148. return url_path.split('/', 1)[1]
  149. return ''