Arthur 2 лет назад
Родитель
Сommit
7779e87ff3

+ 0 - 1
netbox/dcim/graphql/types.py

@@ -6,7 +6,6 @@ from extras.graphql.mixins import (
     ChangelogMixin, ConfigContextMixin, ContactsMixin, CustomFieldsMixin, ImageAttachmentsMixin, TagsMixin,
 )
 from ipam.graphql.mixins import IPAddressesMixin, VLANGroupsMixin
-from netbox.graphql.scalars import BigInt
 from netbox.graphql.types import BaseObjectType, OrganizationalObjectType, NetBoxObjectType
 from .filters import *
 from .mixins import CabledObjectMixin, PathEndpointMixin

+ 0 - 1
netbox/ipam/graphql/types.py

@@ -2,7 +2,6 @@ import strawberry
 import strawberry_django
 
 from ipam import models
-from netbox.graphql.scalars import BigInt
 from netbox.graphql.types import BaseObjectType, OrganizationalObjectType, NetBoxObjectType
 from .filters import *
 

+ 0 - 69
netbox/netbox/graphql/__init__.py

@@ -1,69 +0,0 @@
-import graphene
-from dcim.fields import MACAddressField, WWNField
-from django.db import models
-from graphene import Dynamic
-from graphene_django.converter import convert_django_field, get_django_field_description
-from graphene_django.fields import DjangoConnectionField
-from ipam.fields import IPAddressField, IPNetworkField
-from taggit.managers import TaggableManager
-
-from .fields import ObjectListField
-
-
-@convert_django_field.register(TaggableManager)
-def convert_field_to_tags_list(field, registry=None):
-    """
-    Register conversion handler for django-taggit's TaggableManager
-    """
-    return graphene.List(graphene.String)
-
-
-@convert_django_field.register(IPAddressField)
-@convert_django_field.register(IPNetworkField)
-@convert_django_field.register(MACAddressField)
-@convert_django_field.register(WWNField)
-def convert_field_to_string(field, registry=None):
-    # TODO: Update to use get_django_field_description under django_graphene v3.0
-    return graphene.String(description=field.help_text, required=not field.null)
-
-
-@convert_django_field.register(models.ManyToManyField)
-@convert_django_field.register(models.ManyToManyRel)
-@convert_django_field.register(models.ManyToOneRel)
-def convert_field_to_list_or_connection(field, registry=None):
-    """
-    From graphene_django.converter.py we need to monkey-patch this to return
-    our ObjectListField with filtering support instead of DjangoListField
-    """
-    model = field.related_model
-
-    def dynamic_type():
-        _type = registry.get_type_for_model(model)
-        if not _type:
-            return
-
-        if isinstance(field, models.ManyToManyField):
-            description = get_django_field_description(field)
-        else:
-            description = get_django_field_description(field.field)
-
-        # If there is a connection, we should transform the field
-        # into a DjangoConnectionField
-        if _type._meta.connection:
-            # Use a DjangoFilterConnectionField if there are
-            # defined filter_fields or a filterset_class in the
-            # DjangoObjectType Meta
-            if _type._meta.filter_fields or _type._meta.filterset_class:
-                from .filter.fields import DjangoFilterConnectionField
-
-                return DjangoFilterConnectionField(_type, required=True, description=description)
-
-            return DjangoConnectionField(_type, required=True, description=description)
-
-        return ObjectListField(
-            _type,
-            required=True,  # A Set is always returned, never None.
-            description=description,
-        )
-
-    return Dynamic(dynamic_type)

+ 0 - 70
netbox/netbox/graphql/fields.py

@@ -1,70 +0,0 @@
-from functools import partial
-
-import graphene
-from graphene_django import DjangoListField
-from .utils import get_graphene_type
-
-__all__ = (
-    'ObjectField',
-    'ObjectListField',
-)
-
-
-class ObjectField(graphene.Field):
-    """
-    Retrieve a single object, identified by its numeric ID.
-    """
-    def __init__(self, *args, **kwargs):
-
-        if 'id' not in kwargs:
-            kwargs['id'] = graphene.Int(required=True)
-
-        super().__init__(*args, **kwargs)
-
-    @staticmethod
-    def object_resolver(django_object_type, root, info, **args):
-        """
-        Return an object given its numeric ID.
-        """
-        manager = django_object_type._meta.model._default_manager
-        queryset = django_object_type.get_queryset(manager, info)
-
-        return queryset.get(**args)
-
-    def get_resolver(self, parent_resolver):
-        return partial(self.object_resolver, self._type)
-
-
-class ObjectListField(DjangoListField):
-    """
-    Retrieve a list of objects, optionally filtered by one or more FilterSet filters.
-    """
-    def __init__(self, _type, *args, **kwargs):
-        filter_kwargs = {}
-
-        # Get FilterSet kwargs
-        filterset_class = getattr(_type._meta, 'filterset_class', None)
-        if filterset_class:
-            for filter_name, filter_field in filterset_class.get_filters().items():
-                field_type = get_graphene_type(type(filter_field))
-                filter_kwargs[filter_name] = graphene.Argument(field_type)
-
-        super().__init__(_type, args=filter_kwargs, *args, **kwargs)
-
-    @staticmethod
-    def list_resolver(django_object_type, resolver, default_manager, root, info, **args):
-        queryset = super(ObjectListField, ObjectListField).list_resolver(django_object_type, resolver, default_manager, root, info, **args)
-
-        # if there are no filter params then don't need to filter
-        if not args:
-            return queryset
-
-        filterset_class = django_object_type._meta.filterset_class
-        if filterset_class:
-            filterset = filterset_class(data=args if args else None, queryset=queryset, request=info.context)
-
-            if not filterset.is_valid():
-                return queryset.none()
-            return filterset.qs
-
-        return queryset

+ 0 - 23
netbox/netbox/graphql/scalars.py

@@ -1,23 +0,0 @@
-from graphene import Scalar
-from graphql.language import ast
-from graphene.types.scalars import MAX_INT, MIN_INT
-
-
-class BigInt(Scalar):
-    """
-    Handle any BigInts
-    """
-    @staticmethod
-    def to_float(value):
-        num = int(value)
-        if num > MAX_INT or num < MIN_INT:
-            return float(num)
-        return num
-
-    serialize = to_float
-    parse_value = to_float
-
-    @staticmethod
-    def parse_literal(node):
-        if isinstance(node, ast.IntValue):
-            return BigInt.to_float(node.value)

+ 0 - 252
netbox/utilities/graphql_optimizer.py

@@ -1,252 +0,0 @@
-import functools
-
-from django.core.exceptions import FieldDoesNotExist
-from django.db.models import ForeignKey
-from django.db.models.constants import LOOKUP_SEP
-from django.db.models.fields.reverse_related import ManyToOneRel
-from graphene import InputObjectType
-from graphene.types.generic import GenericScalar
-from graphene.types.resolver import default_resolver
-from graphene_django import DjangoObjectType
-from graphql import GraphQLResolveInfo, GraphQLSchema
-from graphql.execution.execute import get_field_def
-from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode
-from graphql.pyutils import Path
-from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType
-
-__all__ = (
-    'gql_query_optimizer',
-)
-
-
-def gql_query_optimizer(queryset, info, **options):
-    return QueryOptimizer(info).optimize(queryset)
-
-
-class QueryOptimizer(object):
-    def __init__(self, info, **options):
-        self.root_info = info
-
-    def optimize(self, queryset):
-        info = self.root_info
-        field_def = get_field_def(info.schema, info.parent_type, info.field_nodes[0])
-
-        field_names = self._optimize_gql_selections(
-            self._get_type(field_def),
-            info.field_nodes[0],
-        )
-
-        qs = queryset.prefetch_related(*field_names)
-        return qs
-
-    def _get_type(self, field_def):
-        a_type = field_def.type
-        while hasattr(a_type, "of_type"):
-            a_type = a_type.of_type
-        return a_type
-
-    def _get_graphql_schema(self, schema):
-        if isinstance(schema, GraphQLSchema):
-            return schema
-        else:
-            return schema.graphql_schema
-
-    def _get_possible_types(self, graphql_type):
-        if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
-            graphql_schema = self._get_graphql_schema(self.root_info.schema)
-            return graphql_schema.get_possible_types(graphql_type)
-        else:
-            return (graphql_type,)
-
-    def _get_base_model(self, graphql_types):
-        models = tuple(t.graphene_type._meta.model for t in graphql_types)
-        for model in models:
-            if all(issubclass(m, model) for m in models):
-                return model
-        return None
-
-    def handle_inline_fragment(self, selection, schema, possible_types, field_names):
-        fragment_type_name = selection.type_condition.name.value
-        graphql_schema = self._get_graphql_schema(schema)
-        fragment_type = graphql_schema.get_type(fragment_type_name)
-        fragment_possible_types = self._get_possible_types(fragment_type)
-        for fragment_possible_type in fragment_possible_types:
-            fragment_model = fragment_possible_type.graphene_type._meta.model
-            parent_model = self._get_base_model(possible_types)
-            if not parent_model:
-                continue
-            path_from_parent = fragment_model._meta.get_path_from_parent(parent_model)
-            select_related_name = LOOKUP_SEP.join(p.join_field.name for p in path_from_parent)
-            if not select_related_name:
-                continue
-            sub_field_names = self._optimize_gql_selections(
-                fragment_possible_type,
-                selection,
-            )
-            field_names.append(select_related_name)
-        return
-
-    def handle_fragment_spread(self, field_names, name, field_type):
-        fragment = self.root_info.fragments[name]
-        sub_field_names = self._optimize_gql_selections(
-            field_type,
-            fragment,
-        )
-
-    def _optimize_gql_selections(self, field_type, field_ast):
-        field_names = []
-        selection_set = field_ast.selection_set
-        if not selection_set:
-            return field_names
-        optimized_fields_by_model = {}
-        schema = self.root_info.schema
-        graphql_schema = self._get_graphql_schema(schema)
-        graphql_type = graphql_schema.get_type(field_type.name)
-
-        possible_types = self._get_possible_types(graphql_type)
-        for selection in selection_set.selections:
-            if isinstance(selection, InlineFragmentNode):
-                self.handle_inline_fragment(selection, schema, possible_types, field_names)
-            else:
-                name = selection.name.value
-                if isinstance(selection, FragmentSpreadNode):
-                    self.handle_fragment_spread(field_names, name, field_type)
-                else:
-                    for possible_type in possible_types:
-                        selection_field_def = possible_type.fields.get(name)
-                        if not selection_field_def:
-                            continue
-
-                        graphene_type = possible_type.graphene_type
-                        model = getattr(graphene_type._meta, "model", None)
-                        if model and name not in optimized_fields_by_model:
-                            field_model = optimized_fields_by_model[name] = model
-                            if field_model == model:
-                                self._optimize_field(
-                                    field_names,
-                                    model,
-                                    selection,
-                                    selection_field_def,
-                                    possible_type,
-                                )
-        return field_names
-
-    def _get_field_info(self, field_names, model, selection, field_def):
-        name = None
-        model_field = None
-        name = self._get_name_from_resolver(field_def.resolve)
-        if not name and callable(field_def.resolve) and not isinstance(field_def.resolve, functools.partial):
-            name = selection.name.value
-        if name:
-            model_field = self._get_model_field_from_name(model, name)
-
-        return (name, model_field)
-
-    def _optimize_field(self, field_names, model, selection, field_def, parent_type):
-        name, model_field = self._get_field_info(field_names, model, selection, field_def)
-        if model_field:
-            self._optimize_field_by_name(field_names, model, selection, field_def, name, model_field)
-
-        return
-
-    def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field):
-        if model_field.many_to_one or model_field.one_to_one:
-            sub_field_names = self._optimize_gql_selections(
-                self._get_type(field_def),
-                selection,
-            )
-            if name not in field_names:
-                field_names.append(name)
-
-            for field in sub_field_names:
-                prefetch_key = f"{name}__{field}"
-                if prefetch_key not in field_names:
-                    field_names.append(prefetch_key)
-
-        if model_field.one_to_many or model_field.many_to_many:
-            sub_field_names = self._optimize_gql_selections(
-                self._get_type(field_def),
-                selection,
-            )
-
-            if isinstance(model_field, ManyToOneRel):
-                sub_field_names.append(model_field.field.name)
-
-            field_names.append(name)
-            for field in sub_field_names:
-                prefetch_key = f"{name}__{field}"
-                if prefetch_key not in field_names:
-                    field_names.append(prefetch_key)
-
-        return
-
-    def _get_optimization_hints(self, resolver):
-        return getattr(resolver, "optimization_hints", None)
-
-    def _get_value(self, info, value):
-        if isinstance(value, VariableNode):
-            var_name = value.name.value
-            value = info.variable_values.get(var_name)
-            return value
-        elif isinstance(value, InputObjectType):
-            return value.__dict__
-        else:
-            return GenericScalar.parse_literal(value)
-
-    def _get_name_from_resolver(self, resolver):
-        optimization_hints = self._get_optimization_hints(resolver)
-        if optimization_hints:
-            name_fn = optimization_hints.model_field
-            if name_fn:
-                return name_fn()
-        if self._is_resolver_for_id_field(resolver):
-            return "id"
-        elif isinstance(resolver, functools.partial):
-            resolver_fn = resolver
-            if resolver_fn.func != default_resolver:
-                # Some resolvers have the partial function as the second
-                # argument.
-                for arg in resolver_fn.args:
-                    if isinstance(arg, (str, functools.partial)):
-                        break
-                else:
-                    # No suitable instances found, default to first arg
-                    arg = resolver_fn.args[0]
-                resolver_fn = arg
-            if isinstance(resolver_fn, functools.partial) and resolver_fn.func == default_resolver:
-                return resolver_fn.args[0]
-            if self._is_resolver_for_id_field(resolver_fn):
-                return "id"
-            return resolver_fn
-
-    def _is_resolver_for_id_field(self, resolver):
-        resolve_id = DjangoObjectType.resolve_id
-        return resolver == resolve_id
-
-    def _get_model_field_from_name(self, model, name):
-        try:
-            return model._meta.get_field(name)
-        except FieldDoesNotExist:
-            descriptor = model.__dict__.get(name)
-            if not descriptor:
-                return None
-            return getattr(descriptor, "rel", None) or getattr(descriptor, "related", None)  # Django < 1.9
-
-    def _is_foreign_key_id(self, model_field, name):
-        return isinstance(model_field, ForeignKey) and model_field.name != name and model_field.get_attname() == name
-
-    def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
-        return GraphQLResolveInfo(
-            field_name,
-            field_asts,
-            return_type,
-            parent_type,
-            Path(None, 0, None),
-            schema=self.root_info.schema,
-            fragments=self.root_info.fragments,
-            root_value=self.root_info.root_value,
-            operation=self.root_info.operation,
-            variable_values=self.root_info.variable_values,
-            context=self.root_info.context,
-            is_awaitable=self.root_info.is_awaitable,
-        )

+ 4 - 1
netbox/utilities/testing/api.py

@@ -446,7 +446,10 @@ class APIViewTestCases:
 
             # Compile list of fields to include
             fields_string = ''
-            for field_name, field in type_class._meta.fields.items():
+
+            for field_name, field in type_class.__dataclass_fields__.items():
+                # for field_name, field in type_class._meta.fields.items():
+                print(f"field_name: {field_name} field: {field}")
                 is_string_array = False
                 if type(field.type) is GQLList:
                     if field.type.of_type is GQLString: