# custom_inspectors.py (7.1 KB)

  1. from django.contrib.postgres.fields import JSONField
  2. from drf_yasg import openapi
  3. from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, FilterInspector, SwaggerAutoSchema
  4. from drf_yasg.utils import get_serializer_ref_name
  5. from rest_framework.fields import ChoiceField
  6. from rest_framework.relations import ManyRelatedField
  7. from taggit_serializer.serializers import TagListSerializerField
  8. from dcim.api.serializers import InterfaceSerializer as DeviceInterfaceSerializer
  9. from extras.api.customfields import CustomFieldsSerializer
  10. from utilities.api import ChoiceField, SerializedPKRelatedField, WritableNestedSerializer
  11. from virtualization.api.serializers import InterfaceSerializer as VirtualMachineInterfaceSerializer
  12. # this might be ugly, but it limits drf_yasg-specific code to this file
  13. DeviceInterfaceSerializer.Meta.ref_name = 'DeviceInterface'
  14. VirtualMachineInterfaceSerializer.Meta.ref_name = 'VirtualMachineInterface'
  15. class NetBoxSwaggerAutoSchema(SwaggerAutoSchema):
  16. writable_serializers = {}
  17. def get_request_serializer(self):
  18. serializer = super().get_request_serializer()
  19. if serializer is not None and self.method in self.implicit_body_methods:
  20. properties = {}
  21. for child_name, child in serializer.fields.items():
  22. if isinstance(child, (ChoiceField, WritableNestedSerializer)):
  23. properties[child_name] = None
  24. elif isinstance(child, ManyRelatedField) and isinstance(child.child_relation, SerializedPKRelatedField):
  25. properties[child_name] = None
  26. if properties:
  27. if type(serializer) not in self.writable_serializers:
  28. writable_name = 'Writable' + type(serializer).__name__
  29. meta_class = getattr(type(serializer), 'Meta', None)
  30. if meta_class:
  31. ref_name = 'Writable' + get_serializer_ref_name(serializer)
  32. writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name})
  33. properties['Meta'] = writable_meta
  34. self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)
  35. writable_class = self.writable_serializers[type(serializer)]
  36. serializer = writable_class()
  37. return serializer
  38. class SerializedPKRelatedFieldInspector(FieldInspector):
  39. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  40. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  41. if isinstance(field, SerializedPKRelatedField):
  42. return self.probe_field_inspectors(field.serializer(), ChildSwaggerType, use_references)
  43. return NotHandled
  44. class TagListFieldInspector(FieldInspector):
  45. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  46. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  47. if isinstance(field, TagListSerializerField):
  48. child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
  49. return SwaggerType(
  50. type=openapi.TYPE_ARRAY,
  51. items=child_schema,
  52. )
  53. return NotHandled
  54. class CustomChoiceFieldInspector(FieldInspector):
  55. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  56. # this returns a callable which extracts title, description and other stuff
  57. # https://drf-yasg.readthedocs.io/en/stable/_modules/drf_yasg/inspectors/base.html#FieldInspector._get_partial_types
  58. SwaggerType, _ = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  59. if isinstance(field, ChoiceField):
  60. choices = field._choices
  61. choice_value = list(choices.keys())
  62. choice_label = list(choices.values())
  63. value_schema = openapi.Schema(type=openapi.TYPE_STRING, enum=choice_value)
  64. if set([None] + choice_value) == {None, True, False}:
  65. # DeviceType.subdevice_role, Device.face and InterfaceConnection.connection_status all need to be
  66. # differentiated since they each have subtly different values in their choice keys.
  67. # - subdevice_role and connection_status are booleans, although subdevice_role includes None
  68. # - face is an integer set {0, 1} which is easily confused with {False, True}
  69. schema_type = openapi.TYPE_STRING
  70. if all(type(x) == bool for x in [c for c in choice_value if c is not None]):
  71. schema_type = openapi.TYPE_BOOLEAN
  72. value_schema = openapi.Schema(type=schema_type, enum=choice_value)
  73. value_schema['x-nullable'] = True
  74. if isinstance(choice_value[0], int):
  75. # Change value_schema for IPAddressFamilyChoices, RackWidthChoices
  76. value_schema = openapi.Schema(type=openapi.TYPE_INTEGER, enum=choice_value)
  77. schema = SwaggerType(type=openapi.TYPE_OBJECT, required=["label", "value"], properties={
  78. "label": openapi.Schema(type=openapi.TYPE_STRING, enum=choice_label),
  79. "value": value_schema
  80. })
  81. return schema
  82. elif isinstance(field, CustomFieldsSerializer):
  83. schema = SwaggerType(type=openapi.TYPE_OBJECT)
  84. return schema
  85. return NotHandled
  86. class NullableBooleanFieldInspector(FieldInspector):
  87. def process_result(self, result, method_name, obj, **kwargs):
  88. if isinstance(result, openapi.Schema) and isinstance(obj, ChoiceField) and result.type == 'boolean':
  89. keys = obj.choices.keys()
  90. if set(keys) == {None, True, False}:
  91. result['x-nullable'] = True
  92. result.type = 'boolean'
  93. return result
  94. class JSONFieldInspector(FieldInspector):
  95. """Required because by default, Swagger sees a JSONField as a string and not dict
  96. """
  97. def process_result(self, result, method_name, obj, **kwargs):
  98. if isinstance(result, openapi.Schema) and isinstance(obj, JSONField):
  99. result.type = 'dict'
  100. return result
  101. class IdInFilterInspector(FilterInspector):
  102. def process_result(self, result, method_name, obj, **kwargs):
  103. if isinstance(result, list):
  104. params = [p for p in result if isinstance(p, openapi.Parameter) and p.name == 'id__in']
  105. for p in params:
  106. p.type = 'string'
  107. return result
  108. class NullablePaginatorInspector(PaginatorInspector):
  109. def process_result(self, result, method_name, obj, **kwargs):
  110. if method_name == 'get_paginated_response' and isinstance(result, openapi.Schema):
  111. next = result.properties['next']
  112. if isinstance(next, openapi.Schema):
  113. next['x-nullable'] = True
  114. previous = result.properties['previous']
  115. if isinstance(previous, openapi.Schema):
  116. previous['x-nullable'] = True
  117. return result