custom_inspectors.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from django.contrib.postgres.fields import JSONField
  2. from drf_yasg import openapi
  3. from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, SwaggerAutoSchema
  4. from drf_yasg.utils import get_serializer_ref_name
  5. from rest_framework.fields import ChoiceField
  6. from rest_framework.relations import ManyRelatedField
  7. from extras.api.customfields import CustomFieldsDataField
  8. from netbox.api import ChoiceField, SerializedPKRelatedField, WritableNestedSerializer
  9. class NetBoxSwaggerAutoSchema(SwaggerAutoSchema):
  10. writable_serializers = {}
  11. def get_operation_id(self, operation_keys=None):
  12. operation_keys = operation_keys or self.operation_keys
  13. operation_id = self.overrides.get('operation_id', '')
  14. if not operation_id:
  15. # Overwrite the action for bulk update/bulk delete views to ensure they get an operation ID that's
  16. # unique from their single-object counterparts (see #3436)
  17. if operation_keys[-1] in ('delete', 'partial_update', 'update') and not self.view.detail:
  18. operation_keys[-1] = f'bulk_{operation_keys[-1]}'
  19. operation_id = '_'.join(operation_keys)
  20. return operation_id
  21. def get_request_serializer(self):
  22. serializer = super().get_request_serializer()
  23. if serializer is not None and self.method in self.implicit_body_methods:
  24. writable_class = self.get_writable_class(serializer)
  25. if writable_class is not None:
  26. if hasattr(serializer, 'child'):
  27. child_serializer = self.get_writable_class(serializer.child)
  28. serializer = writable_class(child=child_serializer)
  29. else:
  30. serializer = writable_class()
  31. return serializer
  32. def get_writable_class(self, serializer):
  33. properties = {}
  34. fields = {} if hasattr(serializer, 'child') else serializer.fields
  35. for child_name, child in fields.items():
  36. if isinstance(child, (ChoiceField, WritableNestedSerializer)):
  37. properties[child_name] = None
  38. elif isinstance(child, ManyRelatedField) and isinstance(child.child_relation, SerializedPKRelatedField):
  39. properties[child_name] = None
  40. if properties:
  41. if type(serializer) not in self.writable_serializers:
  42. writable_name = 'Writable' + type(serializer).__name__
  43. meta_class = getattr(type(serializer), 'Meta', None)
  44. if meta_class:
  45. ref_name = 'Writable' + get_serializer_ref_name(serializer)
  46. writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name})
  47. properties['Meta'] = writable_meta
  48. self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)
  49. writable_class = self.writable_serializers[type(serializer)]
  50. return writable_class
  51. class SerializedPKRelatedFieldInspector(FieldInspector):
  52. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  53. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  54. if isinstance(field, SerializedPKRelatedField):
  55. return self.probe_field_inspectors(field.serializer(), ChildSwaggerType, use_references)
  56. return NotHandled
  57. class ChoiceFieldInspector(FieldInspector):
  58. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  59. # this returns a callable which extracts title, description and other stuff
  60. # https://drf-yasg.readthedocs.io/en/stable/_modules/drf_yasg/inspectors/base.html#FieldInspector._get_partial_types
  61. SwaggerType, _ = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  62. if isinstance(field, ChoiceField):
  63. choices = field._choices
  64. choice_value = list(choices.keys())
  65. choice_label = list(choices.values())
  66. value_schema = openapi.Schema(type=openapi.TYPE_STRING, enum=choice_value)
  67. if set([None] + choice_value) == {None, True, False}:
  68. # DeviceType.subdevice_role and Device.face need to be differentiated since they each have
  69. # subtly different values in their choice keys.
  70. # - subdevice_role and connection_status are booleans, although subdevice_role includes None
  71. # - face is an integer set {0, 1} which is easily confused with {False, True}
  72. schema_type = openapi.TYPE_STRING
  73. if all(type(x) == bool for x in [c for c in choice_value if c is not None]):
  74. schema_type = openapi.TYPE_BOOLEAN
  75. value_schema = openapi.Schema(type=schema_type, enum=choice_value)
  76. value_schema['x-nullable'] = True
  77. if all(type(x) == int for x in [c for c in choice_value if c is not None]):
  78. # Change value_schema for IPAddressFamilyChoices, RackWidthChoices
  79. value_schema = openapi.Schema(type=openapi.TYPE_INTEGER, enum=choice_value)
  80. schema = SwaggerType(type=openapi.TYPE_OBJECT, required=["label", "value"], properties={
  81. "label": openapi.Schema(type=openapi.TYPE_STRING, enum=choice_label),
  82. "value": value_schema
  83. })
  84. return schema
  85. return NotHandled
  86. class NullableBooleanFieldInspector(FieldInspector):
  87. def process_result(self, result, method_name, obj, **kwargs):
  88. if isinstance(result, openapi.Schema) and isinstance(obj, ChoiceField) and result.type == 'boolean':
  89. keys = obj.choices.keys()
  90. if set(keys) == {None, True, False}:
  91. result['x-nullable'] = True
  92. result.type = 'boolean'
  93. return result
  94. class CustomFieldsDataFieldInspector(FieldInspector):
  95. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  96. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  97. if isinstance(field, CustomFieldsDataField) and swagger_object_type == openapi.Schema:
  98. return SwaggerType(type=openapi.TYPE_OBJECT)
  99. return NotHandled
  100. class JSONFieldInspector(FieldInspector):
  101. """Required because by default, Swagger sees a JSONField as a string and not dict
  102. """
  103. def process_result(self, result, method_name, obj, **kwargs):
  104. if isinstance(result, openapi.Schema) and isinstance(obj, JSONField):
  105. result.type = 'dict'
  106. return result
  107. class NullablePaginatorInspector(PaginatorInspector):
  108. def process_result(self, result, method_name, obj, **kwargs):
  109. if method_name == 'get_paginated_response' and isinstance(result, openapi.Schema):
  110. next = result.properties['next']
  111. if isinstance(next, openapi.Schema):
  112. next['x-nullable'] = True
  113. previous = result.properties['previous']
  114. if isinstance(previous, openapi.Schema):
  115. previous['x-nullable'] = True
  116. return result