schema.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. import re
  2. import typing
  3. from collections import OrderedDict
  4. from drf_spectacular.extensions import OpenApiSerializerFieldExtension, OpenApiSerializerExtension, _SchemaType
  5. from drf_spectacular.openapi import AutoSchema
  6. from drf_spectacular.plumbing import (
  7. build_basic_type, build_choice_field, build_media_type_object, build_object_type, get_doc,
  8. )
  9. from drf_spectacular.types import OpenApiTypes
  10. from drf_spectacular.utils import Direction
  11. from netbox.api.fields import ChoiceField
  12. from netbox.api.serializers import WritableNestedSerializer
  13. # see netbox.api.routers.NetBoxRouter
  14. BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update")
  15. WRITABLE_ACTIONS = ("PATCH", "POST", "PUT")
  16. class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension):
  17. target_class = 'timezone_field.rest_framework.TimeZoneSerializerField'
  18. def map_serializer_field(self, auto_schema, direction):
  19. return build_basic_type(OpenApiTypes.STR)
  20. class ChoiceFieldFix(OpenApiSerializerFieldExtension):
  21. target_class = 'netbox.api.fields.ChoiceField'
  22. def map_serializer_field(self, auto_schema, direction):
  23. build_cf = build_choice_field(self.target)
  24. if direction == 'request':
  25. return build_cf
  26. elif direction == "response":
  27. value = build_cf
  28. label = {
  29. **build_basic_type(OpenApiTypes.STR),
  30. "enum": list(OrderedDict.fromkeys(self.target.choices.values()))
  31. }
  32. return build_object_type(
  33. properties={
  34. "value": value,
  35. "label": label
  36. }
  37. )
  38. class NetBoxAutoSchema(AutoSchema):
  39. """
  40. Overrides to drf_spectacular.openapi.AutoSchema to fix following issues:
  41. 1. bulk serializers cause operation_id conflicts with non-bulk ones
  42. 2. bulk operations should specify a list
  43. 3. bulk operations don't have filter params
  44. 4. bulk operations don't have pagination
  45. 5. bulk delete should specify input
  46. """
  47. writable_serializers = {}
  48. @property
  49. def is_bulk_action(self):
  50. if hasattr(self.view, "action") and self.view.action in BULK_ACTIONS:
  51. return True
  52. else:
  53. return False
  54. def get_operation_id(self):
  55. """
  56. bulk serializers cause operation_id conflicts with non-bulk ones
  57. bulk operations cause id conflicts in spectacular resulting in numerous:
  58. Warning: operationId "xxx" has collisions [xxx]. "resolving with numeral suffixes"
  59. code is modified from drf_spectacular.openapi.AutoSchema.get_operation_id
  60. """
  61. if self.is_bulk_action:
  62. tokenized_path = self._tokenize_path()
  63. # replace dashes as they can be problematic later in code generation
  64. tokenized_path = [t.replace('-', '_') for t in tokenized_path]
  65. if self.method == 'GET' and self._is_list_view():
  66. # this shouldn't happen, but keeping it here to follow base code
  67. action = 'list'
  68. else:
  69. # action = self.method_mapping[self.method.lower()]
  70. # use bulk name so partial_update -> bulk_partial_update
  71. action = self.view.action.lower()
  72. if not tokenized_path:
  73. tokenized_path.append('root')
  74. if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
  75. tokenized_path.append('formatted')
  76. return '_'.join(tokenized_path + [action])
  77. # if not bulk - just return normal id
  78. return super().get_operation_id()
  79. def get_request_serializer(self) -> typing.Any:
  80. # bulk operations should specify a list
  81. serializer = super().get_request_serializer()
  82. if self.is_bulk_action:
  83. return type(serializer)(many=True)
  84. # handle mapping for Writable serializers - adapted from dansheps original code
  85. # for drf-yasg
  86. if serializer is not None and self.method in WRITABLE_ACTIONS:
  87. writable_class = self.get_writable_class(serializer)
  88. if writable_class is not None:
  89. if hasattr(serializer, "child"):
  90. child_serializer = self.get_writable_class(serializer.child)
  91. serializer = writable_class(context=serializer.context, child=child_serializer)
  92. else:
  93. serializer = writable_class(context=serializer.context)
  94. return serializer
  95. def get_response_serializers(self) -> typing.Any:
  96. # bulk operations should specify a list
  97. response_serializers = super().get_response_serializers()
  98. if self.is_bulk_action:
  99. return type(response_serializers)(many=True)
  100. return response_serializers
  101. def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str:
  102. name = super()._get_serializer_name(serializer, direction, bypass_extensions)
  103. # If this serializer is nested, prepend its name with "Brief"
  104. if getattr(serializer, 'nested', False):
  105. name = f'Brief{name}'
  106. return name
  107. def get_serializer_ref_name(self, serializer):
  108. # from drf-yasg.utils
  109. """Get serializer's ref_name
  110. :param serializer: Serializer instance
  111. :return: Serializer's ``ref_name`` or ``None`` for inline serializer
  112. :rtype: str or None
  113. """
  114. serializer_meta = getattr(serializer, 'Meta', None)
  115. serializer_name = type(serializer).__name__
  116. if hasattr(serializer_meta, 'ref_name'):
  117. ref_name = serializer_meta.ref_name
  118. else:
  119. ref_name = serializer_name
  120. if ref_name.endswith('Serializer'):
  121. ref_name = ref_name[: -len('Serializer')]
  122. return ref_name
  123. def get_writable_class(self, serializer):
  124. properties = {}
  125. fields = {} if hasattr(serializer, 'child') else serializer.fields
  126. remove_fields = []
  127. # If you get a failure here for "AttributeError: 'cached_property' object has no attribute 'items'"
  128. # it is probably because you are using a viewsets.ViewSet for the API View and are defining a
  129. # serializer_class. You will also need to define a get_serializer() method like for GenericAPIView.
  130. for child_name, child in fields.items():
  131. # read_only fields don't need to be in writable (write only) serializers
  132. if 'read_only' in dir(child) and child.read_only:
  133. remove_fields.append(child_name)
  134. if isinstance(child, (ChoiceField, WritableNestedSerializer)):
  135. properties[child_name] = None
  136. if not properties:
  137. return None
  138. if type(serializer) not in self.writable_serializers:
  139. writable_name = 'Writable' + type(serializer).__name__
  140. meta_class = getattr(type(serializer), 'Meta', None)
  141. if meta_class:
  142. ref_name = 'Writable' + self.get_serializer_ref_name(serializer)
  143. # remove read_only fields from write-only serializers
  144. fields = list(meta_class.fields)
  145. for field in remove_fields:
  146. fields.remove(field)
  147. writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name, 'fields': fields})
  148. properties['Meta'] = writable_meta
  149. self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)
  150. writable_class = self.writable_serializers[type(serializer)]
  151. return writable_class
  152. def get_filter_backends(self):
  153. # bulk operations don't have filter params
  154. if self.is_bulk_action:
  155. return []
  156. return super().get_filter_backends()
  157. def _get_paginator(self):
  158. # bulk operations don't have pagination
  159. if self.is_bulk_action:
  160. return None
  161. return super()._get_paginator()
  162. def _get_request_body(self, direction='request'):
  163. # bulk delete should specify input
  164. if (not self.is_bulk_action) or (self.method != 'DELETE'):
  165. return super()._get_request_body(direction)
  166. # rest from drf_spectacular.openapi.AutoSchema._get_request_body
  167. # but remove the unsafe method check
  168. request_serializer = self.get_request_serializer()
  169. if isinstance(request_serializer, dict):
  170. content = []
  171. request_body_required = True
  172. for media_type, serializer in request_serializer.items():
  173. schema, partial_request_body_required = self._get_request_for_media_type(serializer, direction)
  174. examples = self._get_examples(serializer, direction, media_type)
  175. if schema is None:
  176. continue
  177. content.append((media_type, schema, examples))
  178. request_body_required &= partial_request_body_required
  179. else:
  180. schema, request_body_required = self._get_request_for_media_type(request_serializer, direction)
  181. if schema is None:
  182. return None
  183. content = [
  184. (media_type, schema, self._get_examples(request_serializer, direction, media_type))
  185. for media_type in self.map_parsers()
  186. ]
  187. request_body = {
  188. 'content': {
  189. media_type: build_media_type_object(schema, examples) for media_type, schema, examples in content
  190. }
  191. }
  192. if request_body_required:
  193. request_body['required'] = request_body_required
  194. return request_body
  195. def get_description(self):
  196. """
  197. Return a string description for the ViewSet.
  198. """
  199. # If a docstring is provided, use it.
  200. if self.view.__doc__:
  201. return get_doc(self.view.__class__)
  202. # When the action method is decorated with @action, use the docstring of the method.
  203. action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None)
  204. if action_or_method and action_or_method.__doc__:
  205. return get_doc(action_or_method)
  206. # Else, generate a description from the class name.
  207. return self._generate_description()
  208. def _generate_description(self):
  209. """
  210. Generate a docstring for the method. It also takes into account whether the method is for list or detail.
  211. """
  212. model_name = self.view.queryset.model._meta.verbose_name
  213. # Determine if the method is for list or detail.
  214. if '{id}' in self.path:
  215. return f"{self.method.capitalize()} a {model_name} object."
  216. return f"{self.method.capitalize()} a list of {model_name} objects."
  217. class FixSerializedPKRelatedField(OpenApiSerializerFieldExtension):
  218. target_class = 'netbox.api.fields.SerializedPKRelatedField'
  219. def map_serializer_field(self, auto_schema, direction):
  220. if direction == "response":
  221. component = auto_schema.resolve_serializer(self.target.serializer, direction)
  222. return component.ref if component else None
  223. else:
  224. return build_basic_type(OpenApiTypes.INT)
  225. class FixIntegerRangeSerializerSchema(OpenApiSerializerExtension):
  226. target_class = 'netbox.api.fields.IntegerRangeSerializer'
  227. def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
  228. return {
  229. 'type': 'array',
  230. 'items': {
  231. 'type': 'array',
  232. 'items': {
  233. 'type': 'integer',
  234. },
  235. 'minItems': 2,
  236. 'maxItems': 2,
  237. },
  238. }
  239. # Nested models can be passed by ID in requests
  240. # The logic for this is handled in `BaseModelSerializer.to_internal_value`
  241. class FixWritableNestedSerializerAllowPK(OpenApiSerializerFieldExtension):
  242. target_class = 'netbox.api.serializers.BaseModelSerializer'
  243. match_subclasses = True
  244. def map_serializer_field(self, auto_schema, direction):
  245. schema = auto_schema._map_serializer_field(self.target, direction, bypass_extensions=True)
  246. if schema is None:
  247. return schema
  248. if direction == 'request' and self.target.nested:
  249. return {
  250. 'oneOf': [
  251. build_basic_type(OpenApiTypes.INT),
  252. schema,
  253. ]
  254. }
  255. return schema