schema.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import re
  2. import typing
  3. from collections import OrderedDict
  4. from drf_spectacular.contrib.django_filters import DjangoFilterExtension
  5. from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension, _SchemaType
  6. from drf_spectacular.openapi import AutoSchema
  7. from drf_spectacular.plumbing import (
  8. build_basic_type,
  9. build_choice_field,
  10. build_media_type_object,
  11. build_object_type,
  12. follow_field_source,
  13. get_doc,
  14. )
  15. from drf_spectacular.types import OpenApiTypes
  16. from drf_spectacular.utils import Direction
  17. from netbox.api.fields import ChoiceField
  18. from netbox.api.serializers import WritableNestedSerializer
  19. from netbox.api.viewsets import NetBoxModelViewSet
  20. # see netbox.api.routers.NetBoxRouter
  21. BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update")
  22. WRITABLE_ACTIONS = ("PATCH", "POST", "PUT")
  23. class NetBoxDjangoFilterExtension(DjangoFilterExtension):
  24. """
  25. Overrides drf-spectacular's DjangoFilterExtension to fix a regression in v0.29.0 where
  26. _get_model_field() incorrectly double-appends to_field_name when field_name already ends
  27. with that value (e.g. field_name='tags__slug', to_field_name='slug' produces the invalid
  28. path ['tags', 'slug', 'slug']). This caused hundreds of spurious warnings during schema
  29. generation for filters such as TagFilter, TenancyFilterSet.tenant, and OwnerFilterMixin.owner.
  30. See: https://github.com/netbox-community/netbox/issues/20787
  31. https://github.com/tfranzel/drf-spectacular/issues/1475
  32. """
  33. priority = 1
  34. def _get_model_field(self, filter_field, model):
  35. if not filter_field.field_name:
  36. return None
  37. path = filter_field.field_name.split('__')
  38. to_field_name = filter_field.extra.get('to_field_name')
  39. if to_field_name is not None and path[-1] != to_field_name:
  40. path.append(to_field_name)
  41. return follow_field_source(model, path, emit_warnings=False)
  42. class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension):
  43. target_class = 'timezone_field.rest_framework.TimeZoneSerializerField'
  44. def map_serializer_field(self, auto_schema, direction):
  45. return build_basic_type(OpenApiTypes.STR)
  46. class ChoiceFieldFix(OpenApiSerializerFieldExtension):
  47. target_class = 'netbox.api.fields.ChoiceField'
  48. def map_serializer_field(self, auto_schema, direction):
  49. build_cf = build_choice_field(self.target)
  50. if direction == 'request':
  51. return build_cf
  52. if direction == "response":
  53. value = build_cf
  54. label = {
  55. **build_basic_type(OpenApiTypes.STR),
  56. "enum": list(OrderedDict.fromkeys(self.target.choices.values()))
  57. }
  58. return build_object_type(
  59. properties={
  60. "value": value,
  61. "label": label
  62. }
  63. )
  64. # TODO: This function should never implicitly/explicitly return `None`
  65. # The fallback should be well-defined (drf-spectacular expects request/response naming).
  66. return None
  67. def viewset_handles_bulk_create(view):
  68. """Check if view automatically provides list-based bulk create"""
  69. return isinstance(view, NetBoxModelViewSet)
  70. class NetBoxAutoSchema(AutoSchema):
  71. """
  72. Overrides to drf_spectacular.openapi.AutoSchema to fix following issues:
  73. 1. bulk serializers cause operation_id conflicts with non-bulk ones
  74. 2. bulk operations should specify a list
  75. 3. bulk operations don't have filter params
  76. 4. bulk operations don't have pagination
  77. 5. bulk delete should specify input
  78. """
  79. writable_serializers = {}
  80. @property
  81. def is_bulk_action(self):
  82. if hasattr(self.view, "action") and self.view.action in BULK_ACTIONS:
  83. return True
  84. return False
  85. def get_operation_id(self):
  86. """
  87. bulk serializers cause operation_id conflicts with non-bulk ones
  88. bulk operations cause id conflicts in spectacular resulting in numerous:
  89. Warning: operationId "xxx" has collisions [xxx]. "resolving with numeral suffixes"
  90. code is modified from drf_spectacular.openapi.AutoSchema.get_operation_id
  91. """
  92. if self.is_bulk_action:
  93. tokenized_path = self._tokenize_path()
  94. # replace dashes as they can be problematic later in code generation
  95. tokenized_path = [t.replace('-', '_') for t in tokenized_path]
  96. if self.method == 'GET' and self._is_list_view():
  97. # this shouldn't happen, but keeping it here to follow base code
  98. action = 'list'
  99. else:
  100. # action = self.method_mapping[self.method.lower()]
  101. # use bulk name so partial_update -> bulk_partial_update
  102. action = self.view.action.lower()
  103. if not tokenized_path:
  104. tokenized_path.append('root')
  105. if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
  106. tokenized_path.append('formatted')
  107. return '_'.join(tokenized_path + [action])
  108. # if not bulk - just return normal id
  109. return super().get_operation_id()
  110. def get_request_serializer(self) -> typing.Any:
  111. # bulk operations should specify a list
  112. serializer = super().get_request_serializer()
  113. if self.is_bulk_action:
  114. return type(serializer)(many=True)
  115. # handle mapping for Writable serializers - adapted from dansheps original code
  116. # for drf-yasg
  117. if serializer is not None and self.method in WRITABLE_ACTIONS:
  118. writable_class = self.get_writable_class(serializer)
  119. if writable_class is not None:
  120. if hasattr(serializer, "child"):
  121. child_serializer = self.get_writable_class(serializer.child)
  122. serializer = writable_class(context=serializer.context, child=child_serializer)
  123. else:
  124. serializer = writable_class(context=serializer.context)
  125. return serializer
  126. def get_response_serializers(self) -> typing.Any:
  127. # bulk operations should specify a list
  128. response_serializers = super().get_response_serializers()
  129. if self.is_bulk_action:
  130. return type(response_serializers)(many=True)
  131. return response_serializers
  132. def _get_request_for_media_type(self, serializer, direction='request'):
  133. """
  134. Override to generate oneOf schema for serializers that support both
  135. single object and array input (NetBoxModelViewSet POST operations).
  136. Refs: #20638
  137. """
  138. # Get the standard schema first
  139. schema, required = super()._get_request_for_media_type(serializer, direction)
  140. # If this serializer supports arrays (marked in get_request_serializer),
  141. # wrap the schema in oneOf to allow single object OR array
  142. if (
  143. direction == 'request' and
  144. schema is not None and
  145. getattr(self.view, 'action', None) == 'create' and
  146. viewset_handles_bulk_create(self.view)
  147. ):
  148. return {
  149. 'oneOf': [
  150. schema, # Single object
  151. {
  152. 'type': 'array',
  153. 'items': schema, # Array of objects
  154. }
  155. ]
  156. }, required
  157. return schema, required
  158. def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str:
  159. name = super()._get_serializer_name(serializer, direction, bypass_extensions)
  160. # If this serializer is nested, prepend its name with "Brief"
  161. if getattr(serializer, 'nested', False):
  162. name = f'Brief{name}'
  163. return name
  164. def get_serializer_ref_name(self, serializer):
  165. # from drf-yasg.utils
  166. """Get serializer's ref_name
  167. :param serializer: Serializer instance
  168. :return: Serializer's ``ref_name`` or ``None`` for inline serializer
  169. :rtype: str or None
  170. """
  171. serializer_meta = getattr(serializer, 'Meta', None)
  172. serializer_name = type(serializer).__name__
  173. if hasattr(serializer_meta, 'ref_name'):
  174. ref_name = serializer_meta.ref_name
  175. else:
  176. ref_name = serializer_name
  177. if ref_name.endswith('Serializer'):
  178. ref_name = ref_name[: -len('Serializer')]
  179. return ref_name
  180. def get_writable_class(self, serializer):
  181. properties = {}
  182. fields = {} if hasattr(serializer, 'child') else serializer.fields
  183. remove_fields = []
  184. # If you get a failure here for "AttributeError: 'cached_property' object has no attribute 'items'"
  185. # it is probably because you are using a viewsets.ViewSet for the API View and are defining a
  186. # serializer_class. You will also need to define a get_serializer() method like for GenericAPIView.
  187. for child_name, child in fields.items():
  188. # read_only fields don't need to be in writable (write only) serializers
  189. if 'read_only' in dir(child) and child.read_only:
  190. remove_fields.append(child_name)
  191. if isinstance(child, (ChoiceField, WritableNestedSerializer)):
  192. properties[child_name] = None
  193. if not properties:
  194. return None
  195. if type(serializer) not in self.writable_serializers:
  196. writable_name = 'Writable' + type(serializer).__name__
  197. meta_class = getattr(type(serializer), 'Meta', None)
  198. if meta_class:
  199. ref_name = 'Writable' + self.get_serializer_ref_name(serializer)
  200. # remove read_only fields from write-only serializers
  201. fields = list(meta_class.fields)
  202. for field in remove_fields:
  203. fields.remove(field)
  204. writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name, 'fields': fields})
  205. properties['Meta'] = writable_meta
  206. self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)
  207. writable_class = self.writable_serializers[type(serializer)]
  208. return writable_class
  209. def get_filter_backends(self):
  210. # bulk operations don't have filter params
  211. if self.is_bulk_action:
  212. return []
  213. return super().get_filter_backends()
  214. def _get_paginator(self):
  215. # bulk operations don't have pagination
  216. if self.is_bulk_action:
  217. return None
  218. return super()._get_paginator()
  219. def _get_request_body(self, direction='request'):
  220. # bulk delete should specify input
  221. if (not self.is_bulk_action) or (self.method != 'DELETE'):
  222. return super()._get_request_body(direction)
  223. # rest from drf_spectacular.openapi.AutoSchema._get_request_body
  224. # but remove the unsafe method check
  225. request_serializer = self.get_request_serializer()
  226. if isinstance(request_serializer, dict):
  227. content = []
  228. request_body_required = True
  229. for media_type, serializer in request_serializer.items():
  230. schema, partial_request_body_required = self._get_request_for_media_type(serializer, direction)
  231. examples = self._get_examples(serializer, direction, media_type)
  232. if schema is None:
  233. continue
  234. content.append((media_type, schema, examples))
  235. request_body_required &= partial_request_body_required
  236. else:
  237. schema, request_body_required = self._get_request_for_media_type(request_serializer, direction)
  238. if schema is None:
  239. return None
  240. content = [
  241. (media_type, schema, self._get_examples(request_serializer, direction, media_type))
  242. for media_type in self.map_parsers()
  243. ]
  244. request_body = {
  245. 'content': {
  246. media_type: build_media_type_object(schema, examples) for media_type, schema, examples in content
  247. }
  248. }
  249. if request_body_required:
  250. request_body['required'] = request_body_required
  251. return request_body
  252. def get_description(self):
  253. """
  254. Return a string description for the ViewSet.
  255. """
  256. # If a docstring is provided, use it.
  257. if self.view.__doc__:
  258. return get_doc(self.view.__class__)
  259. # When the action method is decorated with @action, use the docstring of the method.
  260. action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None)
  261. if action_or_method and action_or_method.__doc__:
  262. return get_doc(action_or_method)
  263. # Else, generate a description from the class name.
  264. return self._generate_description()
  265. def _generate_description(self):
  266. """
  267. Generate a docstring for the method. It also takes into account whether the method is for list or detail.
  268. """
  269. model_name = self.view.queryset.model._meta.verbose_name
  270. # Determine if the method is for list or detail.
  271. if '{id}' in self.path:
  272. return f"{self.method.capitalize()} a {model_name} object."
  273. return f"{self.method.capitalize()} a list of {model_name} objects."
  274. class FixSerializedPKRelatedField(OpenApiSerializerFieldExtension):
  275. target_class = 'netbox.api.fields.SerializedPKRelatedField'
  276. def map_serializer_field(self, auto_schema, direction):
  277. if direction == "response":
  278. component = auto_schema.resolve_serializer(self.target.serializer, direction)
  279. return component.ref if component else None
  280. return build_basic_type(OpenApiTypes.INT)
  281. class FixIntegerRangeSerializerSchema(OpenApiSerializerExtension):
  282. target_class = 'netbox.api.fields.IntegerRangeSerializer'
  283. match_subclasses = True
  284. def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
  285. # One range = two integers; many=True will wrap this in an outer array
  286. return {
  287. 'type': 'array',
  288. 'items': {
  289. 'type': 'integer',
  290. },
  291. 'minItems': 2,
  292. 'maxItems': 2,
  293. 'example': [10, 20],
  294. }
  295. # Nested models can be passed by ID in requests
  296. # The logic for this is handled in `BaseModelSerializer.to_internal_value`
  297. class FixWritableNestedSerializerAllowPK(OpenApiSerializerFieldExtension):
  298. target_class = 'netbox.api.serializers.BaseModelSerializer'
  299. match_subclasses = True
  300. def map_serializer_field(self, auto_schema, direction):
  301. schema = auto_schema._map_serializer_field(self.target, direction, bypass_extensions=True)
  302. if schema is None:
  303. return schema
  304. if direction == 'request' and self.target.nested:
  305. return {
  306. 'oneOf': [
  307. build_basic_type(OpenApiTypes.INT),
  308. schema,
  309. ]
  310. }
  311. return schema