| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- import functools
- from django.core.exceptions import FieldDoesNotExist
- from django.db.models import ForeignKey
- from django.db.models.constants import LOOKUP_SEP
- from django.db.models.fields.reverse_related import ManyToOneRel
- from graphene import InputObjectType
- from graphene.types.generic import GenericScalar
- from graphene.types.resolver import default_resolver
- from graphene_django import DjangoObjectType
- from graphql import GraphQLResolveInfo, GraphQLSchema
- from graphql.execution.execute import get_field_def
- from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode
- from graphql.pyutils import Path
- from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType
__all__ = ("gql_query_optimizer",)


def gql_query_optimizer(queryset, info, **options):
    """Return ``queryset`` with prefetches derived from the GraphQL query.

    Convenience wrapper that builds a ``QueryOptimizer`` for ``info`` and
    applies it to ``queryset``.

    Args:
        queryset: Object exposing ``prefetch_related`` (a Django queryset).
        info: Resolve info of the field currently being resolved.
        **options: Extra keyword options, forwarded unchanged to the
            ``QueryOptimizer`` constructor (previously they were dropped).

    Returns:
        Whatever ``QueryOptimizer.optimize`` returns for ``queryset``.
    """
    return QueryOptimizer(info, **options).optimize(queryset)
class QueryOptimizer:
    """Derive ``prefetch_related`` lookups for a queryset from a GraphQL query.

    Starting from the field node found in the resolve info, the optimizer
    walks the selection set, maps each selected GraphQL field to a model
    field, and collects a lookup name for every relational field it finds.
    Nested selections produce ``parent__child`` style lookups.
    """

    def __init__(self, info, **options):
        """Store the resolve info that drives the traversal.

        Args:
            info: Resolve info of the root field; its ``schema``,
                ``fragments``, ``parent_type`` and ``field_nodes`` are read
                later.
            **options: Extra keyword options. They are kept on
                ``self.options`` but are not consulted by this class.
        """
        self.root_info = info
        self.options = options

    def optimize(self, queryset):
        """Return ``queryset.prefetch_related(...)`` for the collected lookups."""
        info = self.root_info
        root_node = info.field_nodes[0]
        field_def = get_field_def(info.schema, info.parent_type, root_node)
        lookups = self._optimize_gql_selections(self._get_type(field_def), root_node)
        return queryset.prefetch_related(*lookups)

    @staticmethod
    def _add_field_names(field_names, names, prefix=None):
        """Append ``names`` to ``field_names`` in place, skipping duplicates.

        Args:
            field_names: List of lookups being accumulated; mutated in place.
            names: Iterable of names to add.
            prefix: When given, each name is added as ``f"{prefix}__{name}"``.
        """
        for name in names:
            key = name if prefix is None else f"{prefix}__{name}"
            if key not in field_names:
                field_names.append(key)

    def _get_type(self, field_def):
        """Return the innermost type of ``field_def.type``.

        Follows ``of_type`` for as long as the current type exposes it.
        """
        a_type = field_def.type
        while hasattr(a_type, "of_type"):
            a_type = a_type.of_type
        return a_type

    def _get_graphql_schema(self, schema):
        """Return ``schema`` if it is a ``GraphQLSchema``, else its ``graphql_schema``."""
        if isinstance(schema, GraphQLSchema):
            return schema
        return schema.graphql_schema

    def _get_possible_types(self, graphql_type):
        """Return the concrete types ``graphql_type`` can stand for.

        Interfaces and unions are expanded through the schema; any other type
        is returned as a one-element tuple.
        """
        if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
            graphql_schema = self._get_graphql_schema(self.root_info.schema)
            return graphql_schema.get_possible_types(graphql_type)
        return (graphql_type,)

    def _get_base_model(self, graphql_types):
        """Return the first model that every model of ``graphql_types`` subclasses.

        Returns ``None`` when no such common base exists (or the input is
        empty).
        """
        models = tuple(t.graphene_type._meta.model for t in graphql_types)
        for candidate in models:
            if all(issubclass(m, candidate) for m in models):
                return candidate
        return None

    def handle_inline_fragment(self, selection, schema, possible_types, field_names):
        """Collect lookups for an inline fragment (``... on Type { ... }``).

        For each concrete type of the fragment, the join path from the common
        parent model to the fragment's model is added to ``field_names``,
        followed by the fragment's own lookups prefixed with that path.

        Args:
            selection: The inline fragment node.
            schema: Schema used to resolve the fragment's type name.
            possible_types: Concrete types of the enclosing selection.
            field_names: List of lookups being accumulated; mutated in place.
        """
        type_condition = selection.type_condition
        if type_condition is None:
            # An inline fragment may omit its type condition; there is no
            # type name to resolve, so skip it instead of crashing.
            return
        graphql_schema = self._get_graphql_schema(schema)
        fragment_type = graphql_schema.get_type(type_condition.name.value)
        fragment_possible_types = self._get_possible_types(fragment_type)
        if not fragment_possible_types:
            return
        # The parent model does not depend on the fragment type, so resolve
        # it once rather than on every iteration.
        parent_model = self._get_base_model(possible_types)
        if not parent_model:
            return
        for fragment_possible_type in fragment_possible_types:
            fragment_model = fragment_possible_type.graphene_type._meta.model
            path_from_parent = fragment_model._meta.get_path_from_parent(parent_model)
            select_related_name = LOOKUP_SEP.join(
                p.join_field.name for p in path_from_parent
            )
            if not select_related_name:
                # Empty path: nothing to join through for this type.
                continue
            sub_field_names = self._optimize_gql_selections(
                fragment_possible_type,
                selection,
            )
            self._add_field_names(field_names, [select_related_name])
            # Previously the nested lookups were computed and then dropped.
            self._add_field_names(
                field_names, sub_field_names, prefix=select_related_name
            )

    def handle_fragment_spread(self, field_names, name, field_type):
        """Merge the lookups of the named fragment ``name`` into ``field_names``.

        The fragment is looked up in ``self.root_info.fragments`` and
        processed against ``field_type``, so its lookups sit at the same
        level as the enclosing selection and are added unprefixed.
        """
        fragment = self.root_info.fragments[name]
        sub_field_names = self._optimize_gql_selections(field_type, fragment)
        # Previously the fragment's lookups were computed and then dropped.
        self._add_field_names(field_names, sub_field_names)

    def _optimize_gql_selections(self, field_type, field_ast):
        """Return the list of lookups implied by ``field_ast``'s selection set.

        Args:
            field_type: Type whose fields are being selected; resolved by
                ``field_type.name`` through the root schema.
            field_ast: Node carrying a ``selection_set`` (possibly empty).

        Returns:
            A new list of lookup names; empty when nothing is selected.
        """
        field_names = []
        selection_set = field_ast.selection_set
        if not selection_set:
            return field_names

        schema = self.root_info.schema
        graphql_schema = self._get_graphql_schema(schema)
        graphql_type = graphql_schema.get_type(field_type.name)
        possible_types = self._get_possible_types(graphql_type)
        # Field name -> model that claimed it, so a field shared by several
        # possible types is processed only once.
        optimized_fields_by_model = {}

        for selection in selection_set.selections:
            if isinstance(selection, InlineFragmentNode):
                self.handle_inline_fragment(
                    selection, schema, possible_types, field_names
                )
                continue
            name = selection.name.value
            if isinstance(selection, FragmentSpreadNode):
                self.handle_fragment_spread(field_names, name, field_type)
                continue
            for possible_type in possible_types:
                selection_field_def = possible_type.fields.get(name)
                if not selection_field_def:
                    continue
                model = getattr(possible_type.graphene_type._meta, "model", None)
                if model and name not in optimized_fields_by_model:
                    optimized_fields_by_model[name] = model
                    self._optimize_field(
                        field_names,
                        model,
                        selection,
                        selection_field_def,
                        possible_type,
                    )
        return field_names

    def _get_field_info(self, field_names, model, selection, field_def):
        """Return ``(name, model_field)`` for a selected GraphQL field.

        The name comes from the resolver when it reveals one. Otherwise, for
        a callable resolver that is not a ``functools.partial``, the GraphQL
        field name is used. ``model_field`` is ``None`` when no name was
        found or the model has no matching field.
        """
        resolver = field_def.resolve
        name = self._get_name_from_resolver(resolver)
        if (
            not name
            and callable(resolver)
            and not isinstance(resolver, functools.partial)
        ):
            name = selection.name.value
        model_field = self._get_model_field_from_name(model, name) if name else None
        return (name, model_field)

    def _optimize_field(self, field_names, model, selection, field_def, parent_type):
        """Add the lookups for one selected field, if it maps to a model field."""
        name, model_field = self._get_field_info(
            field_names, model, selection, field_def
        )
        if model_field:
            self._optimize_field_by_name(
                field_names, model, selection, field_def, name, model_field
            )

    def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field):
        """Add ``name`` and its nested lookups when ``model_field`` is a relation.

        Non-relational fields leave ``field_names`` untouched. For a
        ``ManyToOneRel`` on the to-many side, the name of the field pointing
        back (``model_field.field.name``) is added as a nested lookup too.
        Every addition is skipped if already present.
        """
        to_one = model_field.many_to_one or model_field.one_to_one
        to_many = model_field.one_to_many or model_field.many_to_many
        if not (to_one or to_many):
            return
        sub_field_names = self._optimize_gql_selections(
            self._get_type(field_def),
            selection,
        )
        if to_many and isinstance(model_field, ManyToOneRel):
            sub_field_names.append(model_field.field.name)
        self._add_field_names(field_names, [name])
        self._add_field_names(field_names, sub_field_names, prefix=name)

    def _get_optimization_hints(self, resolver):
        """Return ``resolver.optimization_hints`` or ``None`` when absent."""
        return getattr(resolver, "optimization_hints", None)

    def _get_value(self, info, value):
        """Convert an argument value node into a plain Python value.

        Variables are read from ``info.variable_values``, ``InputObjectType``
        instances yield their ``__dict__``, and anything else is parsed with
        ``GenericScalar.parse_literal``.
        """
        if isinstance(value, VariableNode):
            return info.variable_values.get(value.name.value)
        if isinstance(value, InputObjectType):
            return value.__dict__
        return GenericScalar.parse_literal(value)

    def _get_name_from_resolver(self, resolver):
        """Return the model field name a resolver stands for, if it reveals one.

        Checked in order: an ``optimization_hints.model_field`` callable on
        the resolver, the id resolver (``"id"``), then the arguments bound
        into a ``functools.partial``. Returns ``None`` when nothing matches.
        For a partial, the selected bound argument is returned as-is, so the
        result is not guaranteed to be a string.
        """
        optimization_hints = self._get_optimization_hints(resolver)
        if optimization_hints:
            name_fn = optimization_hints.model_field
            if name_fn:
                return name_fn()
        if self._is_resolver_for_id_field(resolver):
            return "id"
        if not isinstance(resolver, functools.partial):
            return None

        resolver_fn = resolver
        if resolver_fn.func != default_resolver:
            if not resolver_fn.args:
                # Nothing bound positionally: no name to extract.
                return None
            # Some resolvers have the partial function as the second
            # argument; fall back to the first argument when no str/partial
            # is bound at all.
            resolver_fn = next(
                (
                    arg
                    for arg in resolver_fn.args
                    if isinstance(arg, (str, functools.partial))
                ),
                resolver_fn.args[0],
            )
        if (
            isinstance(resolver_fn, functools.partial)
            and resolver_fn.func == default_resolver
        ):
            return resolver_fn.args[0] if resolver_fn.args else None
        if self._is_resolver_for_id_field(resolver_fn):
            return "id"
        return resolver_fn

    def _is_resolver_for_id_field(self, resolver):
        """Return whether ``resolver`` equals ``DjangoObjectType.resolve_id``."""
        return resolver == DjangoObjectType.resolve_id

    def _get_model_field_from_name(self, model, name):
        """Return the field of ``model`` called ``name``, or ``None``.

        Falls back to the ``rel`` / ``related`` attribute of a class-level
        descriptor when ``model._meta.get_field`` does not know the name.
        """
        try:
            return model._meta.get_field(name)
        except FieldDoesNotExist:
            descriptor = model.__dict__.get(name)
            if not descriptor:
                return None
            return getattr(descriptor, "rel", None) or getattr(
                descriptor, "related", None
            )  # Django < 1.9

    def _is_foreign_key_id(self, model_field, name):
        """Return whether ``name`` is the attname (not the name) of a ``ForeignKey``."""
        return (
            isinstance(model_field, ForeignKey)
            and model_field.name != name
            and model_field.get_attname() == name
        )

    def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
        """Build a ``GraphQLResolveInfo`` for a nested field.

        The field-specific values come from the arguments; schema, fragments,
        root value, operation, variables, context and ``is_awaitable`` are
        copied from the root info. The path is a fresh ``Path(None, 0, None)``.
        """
        root = self.root_info
        return GraphQLResolveInfo(
            field_name,
            field_asts,
            return_type,
            parent_type,
            Path(None, 0, None),
            schema=root.schema,
            fragments=root.fragments,
            root_value=root.root_value,
            operation=root.operation,
            variable_values=root.variable_values,
            context=root.context,
            is_awaitable=root.is_awaitable,
        )
|