api.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from __future__ import unicode_literals
  2. from collections import OrderedDict
  3. import pytz
  4. from django.conf import settings
  5. from django.contrib.contenttypes.models import ContentType
  6. from django.core.exceptions import ObjectDoesNotExist
  7. from django.db.models import ManyToManyField
  8. from django.http import Http404
  9. from rest_framework.exceptions import APIException
  10. from rest_framework.permissions import BasePermission
  11. from rest_framework.relations import PrimaryKeyRelatedField
  12. from rest_framework.response import Response
  13. from rest_framework.serializers import Field, ModelSerializer, ValidationError
  14. from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet
  15. from .utils import dynamic_import
  16. class ServiceUnavailable(APIException):
  17. status_code = 503
  18. default_detail = "Service temporarily unavailable, please try again later."
  19. def get_serializer_for_model(model, prefix=''):
  20. """
  21. Dynamically resolve and return the appropriate serializer for a model.
  22. """
  23. app_name, model_name = model._meta.label.split('.')
  24. serializer_name = '{}.api.serializers.{}{}Serializer'.format(
  25. app_name, prefix, model_name
  26. )
  27. try:
  28. return dynamic_import(serializer_name)
  29. except AttributeError:
  30. return None
  31. #
  32. # Authentication
  33. #
  34. class IsAuthenticatedOrLoginNotRequired(BasePermission):
  35. """
  36. Returns True if the user is authenticated or LOGIN_REQUIRED is False.
  37. """
  38. def has_permission(self, request, view):
  39. if not settings.LOGIN_REQUIRED:
  40. return True
  41. return request.user.is_authenticated
  42. #
  43. # Fields
  44. #
  45. class ChoiceField(Field):
  46. """
  47. Represent a ChoiceField as {'value': <DB value>, 'label': <string>}.
  48. """
  49. def __init__(self, choices, **kwargs):
  50. self._choices = dict()
  51. for k, v in choices:
  52. # Unpack grouped choices
  53. if type(v) in [list, tuple]:
  54. for k2, v2 in v:
  55. self._choices[k2] = v2
  56. else:
  57. self._choices[k] = v
  58. super(ChoiceField, self).__init__(**kwargs)
  59. def to_representation(self, obj):
  60. return {'value': obj, 'label': self._choices[obj]}
  61. def to_internal_value(self, data):
  62. # Hotwiring boolean values
  63. if hasattr(data, 'lower'):
  64. if data.lower() == 'true':
  65. return True
  66. if data.lower() == 'false':
  67. return False
  68. return data
  69. class ContentTypeField(Field):
  70. """
  71. Represent a ContentType as '<app_label>.<model>'
  72. """
  73. def to_representation(self, obj):
  74. return "{}.{}".format(obj.app_label, obj.model)
  75. def to_internal_value(self, data):
  76. app_label, model = data.split('.')
  77. try:
  78. return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
  79. except ContentType.DoesNotExist:
  80. raise ValidationError("Invalid content type")
  81. class TimeZoneField(Field):
  82. """
  83. Represent a pytz time zone.
  84. """
  85. def to_representation(self, obj):
  86. return obj.zone if obj else None
  87. def to_internal_value(self, data):
  88. if not data:
  89. return ""
  90. try:
  91. return pytz.timezone(str(data))
  92. except pytz.exceptions.UnknownTimeZoneError:
  93. raise ValidationError('Invalid time zone "{}"'.format(data))
  94. class SerializedPKRelatedField(PrimaryKeyRelatedField):
  95. """
  96. Extends PrimaryKeyRelatedField to return a serialized object on read. This is useful for representing related
  97. objects in a ManyToManyField while still allowing a set of primary keys to be written.
  98. """
  99. def __init__(self, serializer, **kwargs):
  100. self.serializer = serializer
  101. self.pk_field = kwargs.pop('pk_field', None)
  102. super(SerializedPKRelatedField, self).__init__(**kwargs)
  103. def to_representation(self, value):
  104. return self.serializer(value, context={'request': self.context['request']}).data
  105. #
  106. # Serializers
  107. #
  108. # TODO: We should probably take a fresh look at exactly what we're doing with this. There might be a more elegant
  109. # way to enforce model validation on the serializer.
  110. class ValidatedModelSerializer(ModelSerializer):
  111. """
  112. Extends the built-in ModelSerializer to enforce calling clean() on the associated model during validation.
  113. """
  114. def validate(self, data):
  115. # Remove custom fields data and tags (if any) prior to model validation
  116. attrs = data.copy()
  117. attrs.pop('custom_fields', None)
  118. attrs.pop('tags', None)
  119. # Skip ManyToManyFields
  120. for field in self.Meta.model._meta.get_fields():
  121. if isinstance(field, ManyToManyField):
  122. attrs.pop(field.name, None)
  123. # Run clean() on an instance of the model
  124. if self.instance is None:
  125. instance = self.Meta.model(**attrs)
  126. else:
  127. instance = self.instance
  128. for k, v in attrs.items():
  129. setattr(instance, k, v)
  130. instance.clean()
  131. return data
  132. class WritableNestedSerializer(ModelSerializer):
  133. """
  134. Returns a nested representation of an object on read, but accepts only a primary key on write.
  135. """
  136. def to_internal_value(self, data):
  137. if data is None:
  138. return None
  139. try:
  140. return self.Meta.model.objects.get(pk=data)
  141. except ObjectDoesNotExist:
  142. raise ValidationError("Invalid ID")
  143. #
  144. # Viewsets
  145. #
  146. class ModelViewSet(_ModelViewSet):
  147. """
  148. Accept either a single object or a list of objects to create.
  149. """
  150. def get_serializer(self, *args, **kwargs):
  151. # If a list of objects has been provided, initialize the serializer with many=True
  152. if isinstance(kwargs.get('data', {}), list):
  153. kwargs['many'] = True
  154. return super(ModelViewSet, self).get_serializer(*args, **kwargs)
  155. class FieldChoicesViewSet(ViewSet):
  156. """
  157. Expose the built-in numeric values which represent static choices for a model's field.
  158. """
  159. permission_classes = [IsAuthenticatedOrLoginNotRequired]
  160. fields = []
  161. def __init__(self, *args, **kwargs):
  162. super(FieldChoicesViewSet, self).__init__(*args, **kwargs)
  163. # Compile a dict of all fields in this view
  164. self._fields = OrderedDict()
  165. for cls, field_list in self.fields:
  166. for field_name in field_list:
  167. model_name = cls._meta.verbose_name.lower().replace(' ', '-')
  168. key = ':'.join([model_name, field_name])
  169. choices = []
  170. for k, v in cls._meta.get_field(field_name).choices:
  171. if type(v) in [list, tuple]:
  172. for k2, v2 in v:
  173. choices.append({
  174. 'value': k2,
  175. 'label': v2,
  176. })
  177. else:
  178. choices.append({
  179. 'value': k,
  180. 'label': v,
  181. })
  182. self._fields[key] = choices
  183. def list(self, request):
  184. return Response(self._fields)
  185. def retrieve(self, request, pk):
  186. if pk not in self._fields:
  187. raise Http404
  188. return Response(self._fields[pk])
  189. def get_view_name(self):
  190. return "Field Choices"