graphql_optimizer.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import functools
  2. from django.core.exceptions import FieldDoesNotExist
  3. from django.db.models import ForeignKey
  4. from django.db.models.constants import LOOKUP_SEP
  5. from django.db.models.fields.reverse_related import ManyToOneRel
  6. from graphene import InputObjectType
  7. from graphene.types.generic import GenericScalar
  8. from graphene.types.resolver import default_resolver
  9. from graphene_django import DjangoObjectType
  10. from graphql import GraphQLResolveInfo, GraphQLSchema
  11. from graphql.execution.execute import get_field_def
  12. from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode
  13. from graphql.pyutils import Path
  14. from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType
  15. __all__ = (
  16. 'gql_query_optimizer',
  17. )
  18. def gql_query_optimizer(queryset, info, **options):
  19. return QueryOptimizer(info).optimize(queryset)
  20. class QueryOptimizer(object):
  21. def __init__(self, info, **options):
  22. self.root_info = info
  23. def optimize(self, queryset):
  24. info = self.root_info
  25. field_def = get_field_def(info.schema, info.parent_type, info.field_nodes[0])
  26. field_names = self._optimize_gql_selections(
  27. self._get_type(field_def),
  28. info.field_nodes[0],
  29. )
  30. qs = queryset.prefetch_related(*field_names)
  31. return qs
  32. def _get_type(self, field_def):
  33. a_type = field_def.type
  34. while hasattr(a_type, "of_type"):
  35. a_type = a_type.of_type
  36. return a_type
  37. def _get_graphql_schema(self, schema):
  38. if isinstance(schema, GraphQLSchema):
  39. return schema
  40. else:
  41. return schema.graphql_schema
  42. def _get_possible_types(self, graphql_type):
  43. if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
  44. graphql_schema = self._get_graphql_schema(self.root_info.schema)
  45. return graphql_schema.get_possible_types(graphql_type)
  46. else:
  47. return (graphql_type,)
  48. def _get_base_model(self, graphql_types):
  49. models = tuple(t.graphene_type._meta.model for t in graphql_types)
  50. for model in models:
  51. if all(issubclass(m, model) for m in models):
  52. return model
  53. return None
  54. def handle_inline_fragment(self, selection, schema, possible_types, field_names):
  55. fragment_type_name = selection.type_condition.name.value
  56. graphql_schema = self._get_graphql_schema(schema)
  57. fragment_type = graphql_schema.get_type(fragment_type_name)
  58. fragment_possible_types = self._get_possible_types(fragment_type)
  59. for fragment_possible_type in fragment_possible_types:
  60. fragment_model = fragment_possible_type.graphene_type._meta.model
  61. parent_model = self._get_base_model(possible_types)
  62. if not parent_model:
  63. continue
  64. path_from_parent = fragment_model._meta.get_path_from_parent(parent_model)
  65. select_related_name = LOOKUP_SEP.join(p.join_field.name for p in path_from_parent)
  66. if not select_related_name:
  67. continue
  68. sub_field_names = self._optimize_gql_selections(
  69. fragment_possible_type,
  70. selection,
  71. )
  72. field_names.append(select_related_name)
  73. return
  74. def handle_fragment_spread(self, field_names, name, field_type):
  75. fragment = self.root_info.fragments[name]
  76. sub_field_names = self._optimize_gql_selections(
  77. field_type,
  78. fragment,
  79. )
  80. def _optimize_gql_selections(self, field_type, field_ast):
  81. field_names = []
  82. selection_set = field_ast.selection_set
  83. if not selection_set:
  84. return field_names
  85. optimized_fields_by_model = {}
  86. schema = self.root_info.schema
  87. graphql_schema = self._get_graphql_schema(schema)
  88. graphql_type = graphql_schema.get_type(field_type.name)
  89. possible_types = self._get_possible_types(graphql_type)
  90. for selection in selection_set.selections:
  91. if isinstance(selection, InlineFragmentNode):
  92. self.handle_inline_fragment(selection, schema, possible_types, field_names)
  93. else:
  94. name = selection.name.value
  95. if isinstance(selection, FragmentSpreadNode):
  96. self.handle_fragment_spread(field_names, name, field_type)
  97. else:
  98. for possible_type in possible_types:
  99. selection_field_def = possible_type.fields.get(name)
  100. if not selection_field_def:
  101. continue
  102. graphene_type = possible_type.graphene_type
  103. model = getattr(graphene_type._meta, "model", None)
  104. if model and name not in optimized_fields_by_model:
  105. field_model = optimized_fields_by_model[name] = model
  106. if field_model == model:
  107. self._optimize_field(
  108. field_names,
  109. model,
  110. selection,
  111. selection_field_def,
  112. possible_type,
  113. )
  114. return field_names
  115. def _get_field_info(self, field_names, model, selection, field_def):
  116. name = None
  117. model_field = None
  118. name = self._get_name_from_resolver(field_def.resolve)
  119. if not name and callable(field_def.resolve) and not isinstance(field_def.resolve, functools.partial):
  120. name = selection.name.value
  121. if name:
  122. model_field = self._get_model_field_from_name(model, name)
  123. return (name, model_field)
  124. def _optimize_field(self, field_names, model, selection, field_def, parent_type):
  125. name, model_field = self._get_field_info(field_names, model, selection, field_def)
  126. if model_field:
  127. self._optimize_field_by_name(field_names, model, selection, field_def, name, model_field)
  128. return
  129. def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field):
  130. if model_field.many_to_one or model_field.one_to_one:
  131. sub_field_names = self._optimize_gql_selections(
  132. self._get_type(field_def),
  133. selection,
  134. )
  135. if name not in field_names:
  136. field_names.append(name)
  137. for field in sub_field_names:
  138. prefetch_key = f"{name}__{field}"
  139. if prefetch_key not in field_names:
  140. field_names.append(prefetch_key)
  141. if model_field.one_to_many or model_field.many_to_many:
  142. sub_field_names = self._optimize_gql_selections(
  143. self._get_type(field_def),
  144. selection,
  145. )
  146. if isinstance(model_field, ManyToOneRel):
  147. sub_field_names.append(model_field.field.name)
  148. field_names.append(name)
  149. for field in sub_field_names:
  150. prefetch_key = f"{name}__{field}"
  151. if prefetch_key not in field_names:
  152. field_names.append(prefetch_key)
  153. return
  154. def _get_optimization_hints(self, resolver):
  155. return getattr(resolver, "optimization_hints", None)
  156. def _get_value(self, info, value):
  157. if isinstance(value, VariableNode):
  158. var_name = value.name.value
  159. value = info.variable_values.get(var_name)
  160. return value
  161. elif isinstance(value, InputObjectType):
  162. return value.__dict__
  163. else:
  164. return GenericScalar.parse_literal(value)
  165. def _get_name_from_resolver(self, resolver):
  166. optimization_hints = self._get_optimization_hints(resolver)
  167. if optimization_hints:
  168. name_fn = optimization_hints.model_field
  169. if name_fn:
  170. return name_fn()
  171. if self._is_resolver_for_id_field(resolver):
  172. return "id"
  173. elif isinstance(resolver, functools.partial):
  174. resolver_fn = resolver
  175. if resolver_fn.func != default_resolver:
  176. # Some resolvers have the partial function as the second
  177. # argument.
  178. for arg in resolver_fn.args:
  179. if isinstance(arg, (str, functools.partial)):
  180. break
  181. else:
  182. # No suitable instances found, default to first arg
  183. arg = resolver_fn.args[0]
  184. resolver_fn = arg
  185. if isinstance(resolver_fn, functools.partial) and resolver_fn.func == default_resolver:
  186. return resolver_fn.args[0]
  187. if self._is_resolver_for_id_field(resolver_fn):
  188. return "id"
  189. return resolver_fn
  190. def _is_resolver_for_id_field(self, resolver):
  191. resolve_id = DjangoObjectType.resolve_id
  192. return resolver == resolve_id
  193. def _get_model_field_from_name(self, model, name):
  194. try:
  195. return model._meta.get_field(name)
  196. except FieldDoesNotExist:
  197. descriptor = model.__dict__.get(name)
  198. if not descriptor:
  199. return None
  200. return getattr(descriptor, "rel", None) or getattr(descriptor, "related", None) # Django < 1.9
  201. def _is_foreign_key_id(self, model_field, name):
  202. return isinstance(model_field, ForeignKey) and model_field.name != name and model_field.get_attname() == name
  203. def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
  204. return GraphQLResolveInfo(
  205. field_name,
  206. field_asts,
  207. return_type,
  208. parent_type,
  209. Path(None, 0, None),
  210. schema=self.root_info.schema,
  211. fragments=self.root_info.fragments,
  212. root_value=self.root_info.root_value,
  213. operation=self.root_info.operation,
  214. variable_values=self.root_info.variable_values,
  215. context=self.root_info.context,
  216. is_awaitable=self.root_info.is_awaitable,
  217. )