api.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. from collections import OrderedDict
  2. import pytz
  3. from django.conf import settings
  4. from django.contrib.contenttypes.models import ContentType
  5. from django.core.exceptions import ObjectDoesNotExist
  6. from django.db.models import ManyToManyField
  7. from django.http import Http404
  8. from rest_framework.exceptions import APIException
  9. from rest_framework.permissions import BasePermission
  10. from rest_framework.relations import PrimaryKeyRelatedField
  11. from rest_framework.response import Response
  12. from rest_framework.serializers import Field, ModelSerializer, ValidationError
  13. from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet
  14. from .utils import dynamic_import
  15. class ServiceUnavailable(APIException):
  16. status_code = 503
  17. default_detail = "Service temporarily unavailable, please try again later."
  18. def get_serializer_for_model(model, prefix=''):
  19. """
  20. Dynamically resolve and return the appropriate serializer for a model.
  21. """
  22. app_name, model_name = model._meta.label.split('.')
  23. serializer_name = '{}.api.serializers.{}{}Serializer'.format(
  24. app_name, prefix, model_name
  25. )
  26. try:
  27. return dynamic_import(serializer_name)
  28. except AttributeError:
  29. return None
  30. #
  31. # Authentication
  32. #
  33. class IsAuthenticatedOrLoginNotRequired(BasePermission):
  34. """
  35. Returns True if the user is authenticated or LOGIN_REQUIRED is False.
  36. """
  37. def has_permission(self, request, view):
  38. if not settings.LOGIN_REQUIRED:
  39. return True
  40. return request.user.is_authenticated
  41. #
  42. # Fields
  43. #
  44. class ChoiceField(Field):
  45. """
  46. Represent a ChoiceField as {'value': <DB value>, 'label': <string>}.
  47. """
  48. def __init__(self, choices, **kwargs):
  49. self._choices = dict()
  50. for k, v in choices:
  51. # Unpack grouped choices
  52. if type(v) in [list, tuple]:
  53. for k2, v2 in v:
  54. self._choices[k2] = v2
  55. else:
  56. self._choices[k] = v
  57. super(ChoiceField, self).__init__(**kwargs)
  58. def to_representation(self, obj):
  59. if obj is '':
  60. return None
  61. data = OrderedDict([
  62. ('value', obj),
  63. ('label', self._choices[obj])
  64. ])
  65. return data
  66. def to_internal_value(self, data):
  67. # Hotwiring boolean values
  68. if hasattr(data, 'lower'):
  69. if data.lower() == 'true':
  70. return True
  71. if data.lower() == 'false':
  72. return False
  73. return data
  74. class ContentTypeField(Field):
  75. """
  76. Represent a ContentType as '<app_label>.<model>'
  77. """
  78. def to_representation(self, obj):
  79. return "{}.{}".format(obj.app_label, obj.model)
  80. def to_internal_value(self, data):
  81. app_label, model = data.split('.')
  82. try:
  83. return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
  84. except ContentType.DoesNotExist:
  85. raise ValidationError("Invalid content type")
  86. class TimeZoneField(Field):
  87. """
  88. Represent a pytz time zone.
  89. """
  90. def to_representation(self, obj):
  91. return obj.zone if obj else None
  92. def to_internal_value(self, data):
  93. if not data:
  94. return ""
  95. if data not in pytz.common_timezones:
  96. raise ValidationError('Unknown time zone "{}" (see pytz.common_timezones for all options)'.format(data))
  97. return pytz.timezone(data)
  98. class SerializedPKRelatedField(PrimaryKeyRelatedField):
  99. """
  100. Extends PrimaryKeyRelatedField to return a serialized object on read. This is useful for representing related
  101. objects in a ManyToManyField while still allowing a set of primary keys to be written.
  102. """
  103. def __init__(self, serializer, **kwargs):
  104. self.serializer = serializer
  105. self.pk_field = kwargs.pop('pk_field', None)
  106. super(SerializedPKRelatedField, self).__init__(**kwargs)
  107. def to_representation(self, value):
  108. return self.serializer(value, context={'request': self.context['request']}).data
  109. #
  110. # Serializers
  111. #
  112. # TODO: We should probably take a fresh look at exactly what we're doing with this. There might be a more elegant
  113. # way to enforce model validation on the serializer.
  114. class ValidatedModelSerializer(ModelSerializer):
  115. """
  116. Extends the built-in ModelSerializer to enforce calling clean() on the associated model during validation.
  117. """
  118. def validate(self, data):
  119. # Remove custom fields data and tags (if any) prior to model validation
  120. attrs = data.copy()
  121. attrs.pop('custom_fields', None)
  122. attrs.pop('tags', None)
  123. # Skip ManyToManyFields
  124. for field in self.Meta.model._meta.get_fields():
  125. if isinstance(field, ManyToManyField):
  126. attrs.pop(field.name, None)
  127. # Run clean() on an instance of the model
  128. if self.instance is None:
  129. instance = self.Meta.model(**attrs)
  130. else:
  131. instance = self.instance
  132. for k, v in attrs.items():
  133. setattr(instance, k, v)
  134. instance.clean()
  135. return data
  136. class WritableNestedSerializer(ModelSerializer):
  137. """
  138. Returns a nested representation of an object on read, but accepts only a primary key on write.
  139. """
  140. def run_validators(self, value):
  141. # DRF v3.8.2: Skip running validators on the data, since we only accept an integer PK instead of a dict. For
  142. # more context, see:
  143. # https://github.com/encode/django-rest-framework/pull/5922/commits/2227bc47f8b287b66775948ffb60b2d9378ac84f
  144. # https://github.com/encode/django-rest-framework/issues/6053
  145. return
  146. def to_internal_value(self, data):
  147. if data is None:
  148. return None
  149. try:
  150. return self.Meta.model.objects.get(pk=int(data))
  151. except (TypeError, ValueError):
  152. raise ValidationError("Primary key must be an integer")
  153. except ObjectDoesNotExist:
  154. raise ValidationError("Invalid ID")
  155. #
  156. # Viewsets
  157. #
  158. class ModelViewSet(_ModelViewSet):
  159. """
  160. Accept either a single object or a list of objects to create.
  161. """
  162. def get_serializer(self, *args, **kwargs):
  163. # If a list of objects has been provided, initialize the serializer with many=True
  164. if isinstance(kwargs.get('data', {}), list):
  165. kwargs['many'] = True
  166. return super(ModelViewSet, self).get_serializer(*args, **kwargs)
  167. def get_serializer_class(self):
  168. # If 'brief' has been passed as a query param, find and return the nested serializer for this model, if one
  169. # exists
  170. request = self.get_serializer_context()['request']
  171. if request.query_params.get('brief', False):
  172. serializer_class = get_serializer_for_model(self.queryset.model, prefix='Nested')
  173. if serializer_class is not None:
  174. return serializer_class
  175. # Fall back to the hard-coded serializer class
  176. return self.serializer_class
  177. class FieldChoicesViewSet(ViewSet):
  178. """
  179. Expose the built-in numeric values which represent static choices for a model's field.
  180. """
  181. permission_classes = [IsAuthenticatedOrLoginNotRequired]
  182. fields = []
  183. def __init__(self, *args, **kwargs):
  184. super(FieldChoicesViewSet, self).__init__(*args, **kwargs)
  185. # Compile a dict of all fields in this view
  186. self._fields = OrderedDict()
  187. for cls, field_list in self.fields:
  188. for field_name in field_list:
  189. model_name = cls._meta.verbose_name.lower().replace(' ', '-')
  190. key = ':'.join([model_name, field_name])
  191. choices = []
  192. for k, v in cls._meta.get_field(field_name).choices:
  193. if type(v) in [list, tuple]:
  194. for k2, v2 in v:
  195. choices.append({
  196. 'value': k2,
  197. 'label': v2,
  198. })
  199. else:
  200. choices.append({
  201. 'value': k,
  202. 'label': v,
  203. })
  204. self._fields[key] = choices
  205. def list(self, request):
  206. return Response(self._fields)
  207. def retrieve(self, request, pk):
  208. if pk not in self._fields:
  209. raise Http404
  210. return Response(self._fields[pk])
  211. def get_view_name(self):
  212. return "Field Choices"