api.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import logging
  2. from collections import OrderedDict
  3. import pytz
  4. from django.conf import settings
  5. from django.contrib.contenttypes.models import ContentType
  6. from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist, PermissionDenied
  7. from django.db import transaction
  8. from django.db.models import ManyToManyField, ProtectedError
  9. from django.urls import reverse
  10. from rest_framework import serializers
  11. from rest_framework.exceptions import APIException, ValidationError
  12. from rest_framework.permissions import BasePermission
  13. from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
  14. from rest_framework.response import Response
  15. from rest_framework.routers import DefaultRouter
  16. from rest_framework.viewsets import ModelViewSet as _ModelViewSet
  17. from .utils import dict_to_filter_params, dynamic_import
  18. HTTP_ACTIONS = {
  19. 'GET': 'view',
  20. 'OPTIONS': None,
  21. 'HEAD': 'view',
  22. 'POST': 'add',
  23. 'PUT': 'change',
  24. 'PATCH': 'change',
  25. 'DELETE': 'delete',
  26. }
  27. class ServiceUnavailable(APIException):
  28. status_code = 503
  29. default_detail = "Service temporarily unavailable, please try again later."
  30. class SerializerNotFound(Exception):
  31. pass
  32. def get_serializer_for_model(model, prefix=''):
  33. """
  34. Dynamically resolve and return the appropriate serializer for a model.
  35. """
  36. app_name, model_name = model._meta.label.split('.')
  37. # Serializers for Django's auth models are in the users app
  38. if app_name == 'auth':
  39. app_name = 'users'
  40. serializer_name = f'{app_name}.api.serializers.{prefix}{model_name}Serializer'
  41. try:
  42. return dynamic_import(serializer_name)
  43. except AttributeError:
  44. raise SerializerNotFound(
  45. "Could not determine serializer for {}.{} with prefix '{}'".format(app_name, model_name, prefix)
  46. )
  47. def is_api_request(request):
  48. """
  49. Return True of the request is being made via the REST API.
  50. """
  51. api_path = reverse('api-root')
  52. return request.path_info.startswith(api_path)
  53. #
  54. # Authentication
  55. #
  56. class IsAuthenticatedOrLoginNotRequired(BasePermission):
  57. """
  58. Returns True if the user is authenticated or LOGIN_REQUIRED is False.
  59. """
  60. def has_permission(self, request, view):
  61. if not settings.LOGIN_REQUIRED:
  62. return True
  63. return request.user.is_authenticated
  64. #
  65. # Fields
  66. #
  67. class ChoiceField(serializers.Field):
  68. """
  69. Represent a ChoiceField as {'value': <DB value>, 'label': <string>}. Accepts a single value on write.
  70. :param choices: An iterable of choices in the form (value, key).
  71. :param allow_blank: Allow blank values in addition to the listed choices.
  72. """
  73. def __init__(self, choices, allow_blank=False, **kwargs):
  74. self.choiceset = choices
  75. self.allow_blank = allow_blank
  76. self._choices = dict()
  77. # Unpack grouped choices
  78. for k, v in choices:
  79. if type(v) in [list, tuple]:
  80. for k2, v2 in v:
  81. self._choices[k2] = v2
  82. else:
  83. self._choices[k] = v
  84. super().__init__(**kwargs)
  85. def validate_empty_values(self, data):
  86. # Convert null to an empty string unless allow_null == True
  87. if data is None:
  88. if self.allow_null:
  89. return True, None
  90. else:
  91. data = ''
  92. return super().validate_empty_values(data)
  93. def to_representation(self, obj):
  94. if obj is '':
  95. return None
  96. return OrderedDict([
  97. ('value', obj),
  98. ('label', self._choices[obj])
  99. ])
  100. def to_internal_value(self, data):
  101. if data is '':
  102. if self.allow_blank:
  103. return data
  104. raise ValidationError("This field may not be blank.")
  105. # Provide an explicit error message if the request is trying to write a dict or list
  106. if isinstance(data, (dict, list)):
  107. raise ValidationError('Value must be passed directly (e.g. "foo": 123); do not use a dictionary or list.')
  108. # Check for string representations of boolean/integer values
  109. if hasattr(data, 'lower'):
  110. if data.lower() == 'true':
  111. data = True
  112. elif data.lower() == 'false':
  113. data = False
  114. else:
  115. try:
  116. data = int(data)
  117. except ValueError:
  118. pass
  119. try:
  120. if data in self._choices:
  121. return data
  122. except TypeError: # Input is an unhashable type
  123. pass
  124. raise ValidationError(f"{data} is not a valid choice.")
  125. @property
  126. def choices(self):
  127. return self._choices
  128. class ContentTypeField(RelatedField):
  129. """
  130. Represent a ContentType as '<app_label>.<model>'
  131. """
  132. default_error_messages = {
  133. "does_not_exist": "Invalid content type: {content_type}",
  134. "invalid": "Invalid value. Specify a content type as '<app_label>.<model_name>'.",
  135. }
  136. def to_internal_value(self, data):
  137. try:
  138. app_label, model = data.split('.')
  139. return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
  140. except ObjectDoesNotExist:
  141. self.fail('does_not_exist', content_type=data)
  142. except (TypeError, ValueError):
  143. self.fail('invalid')
  144. def to_representation(self, obj):
  145. return "{}.{}".format(obj.app_label, obj.model)
  146. class TimeZoneField(serializers.Field):
  147. """
  148. Represent a pytz time zone.
  149. """
  150. def to_representation(self, obj):
  151. return obj.zone if obj else None
  152. def to_internal_value(self, data):
  153. if not data:
  154. return ""
  155. if data not in pytz.common_timezones:
  156. raise ValidationError('Unknown time zone "{}" (see pytz.common_timezones for all options)'.format(data))
  157. return pytz.timezone(data)
  158. class SerializedPKRelatedField(PrimaryKeyRelatedField):
  159. """
  160. Extends PrimaryKeyRelatedField to return a serialized object on read. This is useful for representing related
  161. objects in a ManyToManyField while still allowing a set of primary keys to be written.
  162. """
  163. def __init__(self, serializer, **kwargs):
  164. self.serializer = serializer
  165. self.pk_field = kwargs.pop('pk_field', None)
  166. super().__init__(**kwargs)
  167. def to_representation(self, value):
  168. return self.serializer(value, context={'request': self.context['request']}).data
  169. #
  170. # Serializers
  171. #
  172. # TODO: We should probably take a fresh look at exactly what we're doing with this. There might be a more elegant
  173. # way to enforce model validation on the serializer.
  174. class ValidatedModelSerializer(serializers.ModelSerializer):
  175. """
  176. Extends the built-in ModelSerializer to enforce calling clean() on the associated model during validation.
  177. """
  178. def validate(self, data):
  179. # Remove custom fields data and tags (if any) prior to model validation
  180. attrs = data.copy()
  181. attrs.pop('custom_fields', None)
  182. attrs.pop('tags', None)
  183. # Skip ManyToManyFields
  184. for field in self.Meta.model._meta.get_fields():
  185. if isinstance(field, ManyToManyField):
  186. attrs.pop(field.name, None)
  187. # Run clean() on an instance of the model
  188. if self.instance is None:
  189. instance = self.Meta.model(**attrs)
  190. else:
  191. instance = self.instance
  192. for k, v in attrs.items():
  193. setattr(instance, k, v)
  194. instance.clean()
  195. instance.validate_unique()
  196. return data
  197. class WritableNestedSerializer(serializers.ModelSerializer):
  198. """
  199. Returns a nested representation of an object on read, but accepts only a primary key on write.
  200. """
  201. def to_internal_value(self, data):
  202. if data is None:
  203. return None
  204. # Dictionary of related object attributes
  205. if isinstance(data, dict):
  206. params = dict_to_filter_params(data)
  207. queryset = self.Meta.model.objects
  208. try:
  209. return queryset.get(**params)
  210. except ObjectDoesNotExist:
  211. raise ValidationError(
  212. "Related object not found using the provided attributes: {}".format(params)
  213. )
  214. except MultipleObjectsReturned:
  215. raise ValidationError(
  216. "Multiple objects match the provided attributes: {}".format(params)
  217. )
  218. except FieldError as e:
  219. raise ValidationError(e)
  220. # Integer PK of related object
  221. if isinstance(data, int):
  222. pk = data
  223. else:
  224. try:
  225. # PK might have been mistakenly passed as a string
  226. pk = int(data)
  227. except (TypeError, ValueError):
  228. raise ValidationError(
  229. "Related objects must be referenced by numeric ID or by dictionary of attributes. Received an "
  230. "unrecognized value: {}".format(data)
  231. )
  232. # Look up object by PK
  233. queryset = self.Meta.model.objects
  234. try:
  235. return queryset.get(pk=int(data))
  236. except ObjectDoesNotExist:
  237. raise ValidationError(
  238. "Related object not found using the provided numeric ID: {}".format(pk)
  239. )
  240. #
  241. # Viewsets
  242. #
  243. class ModelViewSet(_ModelViewSet):
  244. """
  245. Accept either a single object or a list of objects to create.
  246. """
  247. def get_serializer(self, *args, **kwargs):
  248. # If a list of objects has been provided, initialize the serializer with many=True
  249. if isinstance(kwargs.get('data', {}), list):
  250. kwargs['many'] = True
  251. return super().get_serializer(*args, **kwargs)
  252. def get_serializer_class(self):
  253. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  254. # If 'brief' has been passed as a query param, find and return the nested serializer for this model, if one
  255. # exists
  256. request = self.get_serializer_context()['request']
  257. if request.query_params.get('brief'):
  258. logger.debug("Request is for 'brief' format; initializing nested serializer")
  259. try:
  260. serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
  261. logger.debug(f"Using serializer {serializer}")
  262. return serializer
  263. except SerializerNotFound:
  264. pass
  265. # Fall back to the hard-coded serializer class
  266. logger.debug(f"Using serializer {self.serializer_class}")
  267. return self.serializer_class
  268. def initial(self, request, *args, **kwargs):
  269. super().initial(request, *args, **kwargs)
  270. if not request.user.is_authenticated:
  271. return
  272. # Restrict the view's QuerySet to allow only the permitted objects
  273. action = HTTP_ACTIONS[request.method]
  274. if action:
  275. self.queryset = self.queryset.restrict(request.user, action)
  276. def dispatch(self, request, *args, **kwargs):
  277. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  278. try:
  279. return super().dispatch(request, *args, **kwargs)
  280. except ProtectedError as e:
  281. protected_objects = list(e.protected_objects)
  282. msg = f'Unable to delete object. {len(protected_objects)} dependent objects were found: '
  283. msg += ', '.join([f'{obj} ({obj.pk})' for obj in protected_objects])
  284. logger.warning(msg)
  285. return self.finalize_response(
  286. request,
  287. Response({'detail': msg}, status=409),
  288. *args,
  289. **kwargs
  290. )
  291. def _validate_objects(self, instance):
  292. """
  293. Check that the provided instance or list of instances are matched by the current queryset. This confirms that
  294. any newly created or modified objects abide by the attributes granted by any applicable ObjectPermissions.
  295. """
  296. if type(instance) is list:
  297. # Check that all instances are still included in the view's queryset
  298. conforming_count = self.queryset.filter(pk__in=[obj.pk for obj in instance]).count()
  299. if conforming_count != len(instance):
  300. raise ObjectDoesNotExist
  301. else:
  302. # Check that the instance is matched by the view's queryset
  303. self.queryset.get(pk=instance.pk)
  304. def perform_create(self, serializer):
  305. model = self.queryset.model
  306. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  307. logger.info(f"Creating new {model._meta.verbose_name}")
  308. # Enforce object-level permissions on save()
  309. try:
  310. with transaction.atomic():
  311. instance = serializer.save()
  312. self._validate_objects(instance)
  313. except ObjectDoesNotExist:
  314. raise PermissionDenied()
  315. def perform_update(self, serializer):
  316. model = self.queryset.model
  317. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  318. logger.info(f"Updating {model._meta.verbose_name} {serializer.instance} (PK: {serializer.instance.pk})")
  319. # Enforce object-level permissions on save()
  320. try:
  321. with transaction.atomic():
  322. instance = serializer.save()
  323. self._validate_objects(instance)
  324. except ObjectDoesNotExist:
  325. raise PermissionDenied()
  326. def perform_destroy(self, instance):
  327. model = self.queryset.model
  328. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  329. logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})")
  330. return super().perform_destroy(instance)
  331. #
  332. # Routers
  333. #
  334. class OrderedDefaultRouter(DefaultRouter):
  335. def get_api_root_view(self, api_urls=None):
  336. """
  337. Wrap DRF's DefaultRouter to return an alphabetized list of endpoints.
  338. """
  339. api_root_dict = OrderedDict()
  340. list_name = self.routes[0].name
  341. for prefix, viewset, basename in sorted(self.registry, key=lambda x: x[0]):
  342. api_root_dict[prefix] = list_name.format(basename=basename)
  343. return self.APIRootView.as_view(api_root_dict=api_root_dict)