|
|
@@ -12,11 +12,12 @@ from django.urls import reverse
|
|
|
from django.utils.translation import gettext_lazy as _
|
|
|
from rest_framework import status
|
|
|
from rest_framework.serializers import Serializer
|
|
|
-from rest_framework.utils import formatting
|
|
|
+from rest_framework.views import get_view_name as drf_get_view_name
|
|
|
|
|
|
+from extras.constants import HTTP_CONTENT_TYPE_JSON
|
|
|
from netbox.api.fields import RelatedObjectCountField
|
|
|
from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound
|
|
|
-from .utils import count_related, dict_to_filter_params, dynamic_import
|
|
|
+from .utils import count_related, dict_to_filter_params, dynamic_import, title
|
|
|
|
|
|
__all__ = (
|
|
|
'get_annotations_for_serializer',
|
|
|
@@ -32,7 +33,7 @@ __all__ = (
|
|
|
|
|
|
def get_serializer_for_model(model, prefix=''):
|
|
|
"""
|
|
|
- Dynamically resolve and return the appropriate serializer for a model.
|
|
|
+ Return the appropriate REST API serializer for the given model.
|
|
|
"""
|
|
|
app_label, model_name = model._meta.label.split('.')
|
|
|
serializer_name = f'{app_label}.api.serializers.{prefix}{model_name}Serializer'
|
|
|
@@ -48,15 +49,12 @@ def get_graphql_type_for_model(model):
|
|
|
"""
|
|
|
Return the GraphQL type class for the given model.
|
|
|
"""
|
|
|
- app_name, model_name = model._meta.label.split('.')
|
|
|
- # Object types for Django's auth models are in the users app
|
|
|
- if app_name == 'auth':
|
|
|
- app_name = 'users'
|
|
|
- class_name = f'{app_name}.graphql.types.{model_name}Type'
|
|
|
+ app_label, model_name = model._meta.label.split('.')
|
|
|
+ class_name = f'{app_label}.graphql.types.{model_name}Type'
|
|
|
try:
|
|
|
return dynamic_import(class_name)
|
|
|
except AttributeError:
|
|
|
- raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_name}.{model_name}")
|
|
|
+ raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_label}.{model_name}")
|
|
|
|
|
|
|
|
|
def is_api_request(request):
|
|
|
@@ -64,30 +62,23 @@ def is_api_request(request):
|
|
|
Return True of the request is being made via the REST API.
|
|
|
"""
|
|
|
api_path = reverse('api-root')
|
|
|
-
|
|
|
- return request.path_info.startswith(api_path) and request.content_type == 'application/json'
|
|
|
+ return request.path_info.startswith(api_path) and request.content_type == HTTP_CONTENT_TYPE_JSON
|
|
|
|
|
|
|
|
|
-def get_view_name(view, suffix=None):
|
|
|
+def get_view_name(view):
|
|
|
"""
|
|
|
- Derive the view name from its associated model, if it has one. Fall back to DRF's built-in `get_view_name`.
|
|
|
+ Derive the view name from its associated model, if it has one. Fall back to DRF's built-in `get_view_name()`.
|
|
|
+ This function is provided to DRF as its VIEW_NAME_FUNCTION.
|
|
|
"""
|
|
|
if hasattr(view, 'queryset'):
|
|
|
- # Determine the model name from the queryset.
|
|
|
- name = view.queryset.model._meta.verbose_name
|
|
|
- name = ' '.join([w[0].upper() + w[1:] for w in name.split()]) # Capitalize each word
|
|
|
-
|
|
|
- else:
|
|
|
- # Replicate DRF's built-in behavior.
|
|
|
- name = view.__class__.__name__
|
|
|
- name = formatting.remove_trailing_string(name, 'View')
|
|
|
- name = formatting.remove_trailing_string(name, 'ViewSet')
|
|
|
- name = formatting.camelcase_to_spaces(name)
|
|
|
-
|
|
|
- if suffix:
|
|
|
- name += ' ' + suffix
|
|
|
-
|
|
|
- return name
|
|
|
+ # Derive the model name from the queryset.
|
|
|
+ name = title(view.queryset.model._meta.verbose_name)
|
|
|
+ if suffix := getattr(view, 'suffix', None):
|
|
|
+ name = f'{name} {suffix}'
|
|
|
+ return name
|
|
|
+
|
|
|
+ # Fall back to DRF's default behavior
|
|
|
+ return drf_get_view_name(view)
|
|
|
|
|
|
|
|
|
def get_prefetches_for_serializer(serializer_class, fields_to_include=None):
|