Sfoglia il codice sorgente

fix(api): Fix schema and field definitions for OpenAPI

Add `get_internal_type()` to custom field classes for Django compatibility,
annotate path parameters and operation IDs for background endpoints, and
provide serializer context on the RQ base viewset to clear schema warnings.

Fixes #20365
Martin Hauser 4 mesi fa
parent
commit
9e75a2f955

+ 1 - 1
netbox/core/api/serializers_/tasks.py

@@ -13,7 +13,7 @@ class BackgroundTaskSerializer(serializers.Serializer):
     url = serializers.HyperlinkedIdentityField(
         view_name='core-api:rqtask-detail',
         lookup_field='id',
-        lookup_url_kwarg='pk'
+        lookup_url_kwarg='id'
     )
     description = serializers.CharField()
     origin = serializers.CharField()

+ 58 - 23
netbox/core/api/views.py

@@ -5,7 +5,7 @@ from django_rq.queues import get_redis_connection
 from django_rq.settings import QUEUES_LIST
 from django_rq.utils import get_statistics
 from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import extend_schema
+from drf_spectacular.utils import OpenApiParameter, extend_schema
 from rest_framework import viewsets
 from rest_framework.decorators import action
 from rest_framework.exceptions import PermissionDenied
@@ -24,6 +24,7 @@ from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
 from netbox.api.metadata import ContentTypeMetadata
 from netbox.api.pagination import LimitOffsetListPagination
 from netbox.api.viewsets import NetBoxModelViewSet, NetBoxReadOnlyModelViewSet
+
 from . import serializers
 
 
@@ -117,29 +118,49 @@ class BaseRQViewSet(viewsets.ViewSet):
     def get_serializer(self, *args, **kwargs):
         """
         Return the serializer instance that should be used for validating and
-        deserializing input, and for serializing output.
+        deserializing input and for serializing output.
         """
         serializer_class = self.get_serializer_class()
         kwargs['context'] = self.get_serializer_context()
         return serializer_class(*args, **kwargs)
 
+    def get_serializer_class(self):
+        """
+        Return the class to use for the serializer.
+        """
+        return self.serializer_class
+
+    def get_serializer_context(self):
+        """
+        Extra context provided to the serializer class.
+        """
+        return {
+            'request': self.request,
+            'format': self.format_kwarg,
+            'view': self,
+        }
+
 
 class BackgroundQueueViewSet(BaseRQViewSet):
     """
     Retrieve a list of RQ Queues.
-    Note: Queue names are not URL safe so not returning a detail view.
+    Note: Queue names are not URL safe, so not returning a detail view.
     """
     serializer_class = serializers.BackgroundQueueSerializer
     lookup_field = 'name'
     lookup_value_regex = r'[\w.@+-]+'
 
     def get_view_name(self):
-        return "Background Queues"
+        return 'Background Queues'
 
     def get_data(self):
-        return get_statistics(run_maintenance_tasks=True)["queues"]
+        return get_statistics(run_maintenance_tasks=True)['queues']
 
-    @extend_schema(responses={200: OpenApiTypes.OBJECT})
+    @extend_schema(
+        operation_id='core_background_queues_retrieve_by_name',
+        parameters=[OpenApiParameter(name='name', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
+        responses={200: OpenApiTypes.OBJECT},
+    )
     def retrieve(self, request, name):
         data = self.get_data()
         if not data:
@@ -161,12 +182,17 @@ class BackgroundWorkerViewSet(BaseRQViewSet):
     lookup_field = 'name'
 
     def get_view_name(self):
-        return "Background Workers"
+        return 'Background Workers'
 
     def get_data(self):
         config = QUEUES_LIST[0]
         return Worker.all(get_redis_connection(config['connection_config']))
 
+    @extend_schema(
+        operation_id='core_background_workers_retrieve_by_name',
+        parameters=[OpenApiParameter(name='name', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
+        responses={200: OpenApiTypes.OBJECT},
+    )
     def retrieve(self, request, name):
         # all the RQ queues should use the same connection
         config = QUEUES_LIST[0]
@@ -184,9 +210,10 @@ class BackgroundTaskViewSet(BaseRQViewSet):
     Retrieve a list of RQ Tasks.
     """
     serializer_class = serializers.BackgroundTaskSerializer
+    lookup_field = 'id'
 
     def get_view_name(self):
-        return "Background Tasks"
+        return 'Background Tasks'
 
     def get_data(self):
         return get_rq_jobs()
@@ -199,45 +226,53 @@ class BackgroundTaskViewSet(BaseRQViewSet):
 
         return task
 
-    @extend_schema(responses={200: OpenApiTypes.OBJECT})
-    def retrieve(self, request, pk):
+    @extend_schema(
+        operation_id='core_background_tasks_retrieve_by_id',
+        parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
+        responses={200: OpenApiTypes.OBJECT},
+    )
+    def retrieve(self, request, id):
         """
         Retrieve the details of the specified RQ Task.
         """
-        task = self.get_task_from_id(pk)
+        task = self.get_task_from_id(id)
         serializer = self.serializer_class(task, context={'request': request})
         return Response(serializer.data)
 
-    @action(methods=["POST"], detail=True)
-    def delete(self, request, pk):
+    @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
+    @action(methods=['POST'], detail=True)
+    def delete(self, request, id):
         """
         Delete the specified RQ Task.
         """
-        delete_rq_job(pk)
+        delete_rq_job(id)
         return HttpResponse(status=200)
 
-    @action(methods=["POST"], detail=True)
-    def requeue(self, request, pk):
+    @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
+    @action(methods=['POST'], detail=True)
+    def requeue(self, request, id):
         """
         Requeues the specified RQ Task.
         """
-        requeue_rq_job(pk)
+        requeue_rq_job(id)
         return HttpResponse(status=200)
 
-    @action(methods=["POST"], detail=True)
-    def enqueue(self, request, pk):
+    @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
+    @action(methods=['POST'], detail=True)
+    def enqueue(self, request, id):
         """
         Enqueues the specified RQ Task.
         """
-        enqueue_rq_job(pk)
+        enqueue_rq_job(id)
         return HttpResponse(status=200)
 
-    @action(methods=["POST"], detail=True)
-    def stop(self, request, pk):
+    @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
+    @action(methods=['POST'], detail=True)
+    def stop(self, request, id):
         """
         Stops the specified RQ Task.
         """
-        stopped_jobs = stop_rq_job(pk)
+        stopped_jobs = stop_rq_job(id)
         if len(stopped_jobs) == 1:
             return HttpResponse(status=200)
         else:

+ 8 - 2
netbox/dcim/fields.py

@@ -26,7 +26,7 @@ class eui64_unix_expanded_uppercase(eui64_unix_expanded):
 #
 
 class MACAddressField(models.Field):
-    description = "PostgreSQL MAC Address field"
+    description = 'PostgreSQL MAC Address field'
 
     def python_type(self):
         return EUI
@@ -34,6 +34,9 @@ class MACAddressField(models.Field):
     def from_db_value(self, value, expression, connection):
         return self.to_python(value)
 
+    def get_internal_type(self):
+        return 'CharField'
+
     def to_python(self, value):
         if value is None:
             return value
@@ -54,7 +57,7 @@ class MACAddressField(models.Field):
 
 
 class WWNField(models.Field):
-    description = "World Wide Name field"
+    description = 'World Wide Name field'
 
     def python_type(self):
         return EUI
@@ -62,6 +65,9 @@ class WWNField(models.Field):
     def from_db_value(self, value, expression, connection):
         return self.to_python(value)
 
+    def get_internal_type(self):
+        return 'CharField'
+
     def to_python(self, value):
         if value is None:
             return value

+ 1 - 0
netbox/extras/api/serializers_/customfields.py

@@ -26,6 +26,7 @@ class CustomFieldChoiceSetSerializer(ChangeLogMessageSerializer, ValidatedModelS
             max_length=2
         )
     )
+    choices_count = serializers.IntegerField(read_only=True)
 
     class Meta:
         model = CustomFieldChoiceSet

+ 6 - 3
netbox/ipam/fields.py

@@ -26,6 +26,9 @@ class BaseIPField(models.Field):
     def from_db_value(self, value, expression, connection):
         return self.to_python(value)
 
+    def get_internal_type(self):
+        return 'CharField'
+
     def to_python(self, value):
         if not value:
             return value
@@ -57,7 +60,7 @@ class IPNetworkField(BaseIPField):
     """
     IP prefix (network and mask)
     """
-    description = "PostgreSQL CIDR field"
+    description = 'PostgreSQL CIDR field'
     default_validators = [validators.prefix_validator]
 
     def db_type(self, connection):
@@ -83,7 +86,7 @@ class IPAddressField(BaseIPField):
     """
     IP address (host address and mask)
     """
-    description = "PostgreSQL INET field"
+    description = 'PostgreSQL INET field'
 
     def db_type(self, connection):
         return 'inet'
@@ -110,7 +113,7 @@ IPAddressField.register_lookup(lookups.Inet)
 
 
 class ASNField(models.BigIntegerField):
-    description = "32-bit ASN field"
+    description = '32-bit ASN field'
     default_validators = [
         MinValueValidator(BGP_ASN_MIN),
         MaxValueValidator(BGP_ASN_MAX),

+ 7 - 5
netbox/ipam/filtersets.py

@@ -354,13 +354,13 @@ class PrefixFilterSet(NetBoxModelFilterSet, ScopedFilterSet, TenancyFilterSet, C
     vlan_group_id = django_filters.ModelMultipleChoiceFilter(
         field_name='vlan__group',
         queryset=VLANGroup.objects.all(),
-        to_field_name="id",
+        to_field_name='id',
         label=_('VLAN Group (ID)'),
     )
     vlan_group = django_filters.ModelMultipleChoiceFilter(
         field_name='vlan__group__slug',
         queryset=VLANGroup.objects.all(),
-        to_field_name="slug",
+        to_field_name='slug',
         label=_('VLAN Group (slug)'),
     )
     vlan_id = django_filters.ModelMultipleChoiceFilter(
@@ -695,12 +695,12 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet, ContactModelFil
         return queryset.filter(q)
 
     def parse_inet_addresses(self, value):
-        '''
+        """
         Parse networks or IP addresses and cast to a format
         acceptable by the Postgres inet type.
 
         Skips invalid values.
-        '''
+        """
         parsed = []
         for addr in value:
             if netaddr.valid_ipv4(addr) or netaddr.valid_ipv6(addr):
@@ -718,7 +718,7 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet, ContactModelFil
         # as argument. If they are all invalid,
         # we return an empty queryset
         value = self.parse_inet_addresses(value)
-        if (len(value) == 0):
+        if len(value) == 0:
             return queryset.none()
 
         try:
@@ -1079,6 +1079,7 @@ class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
     def get_for_virtualmachine(self, queryset, name, value):
         return queryset.get_for_virtualmachine(value)
 
+    @extend_schema_field(OpenApiTypes.INT)
     def filter_interface_id(self, queryset, name, value):
         if value is None:
             return queryset.none()
@@ -1087,6 +1088,7 @@ class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
             Q(interfaces_as_untagged=value)
         ).distinct()
 
+    @extend_schema_field(OpenApiTypes.INT)
     def filter_vminterface_id(self, queryset, name, value):
         if value is None:
             return queryset.none()