api.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. from __future__ import unicode_literals
  2. from collections import OrderedDict
  3. import pytz
  4. from django.conf import settings
  5. from django.contrib.contenttypes.models import ContentType
  6. from django.core.exceptions import ObjectDoesNotExist
  7. from django.db.models import ManyToManyField
  8. from django.http import Http404
  9. from rest_framework import mixins
  10. from rest_framework.exceptions import APIException
  11. from rest_framework.permissions import BasePermission
  12. from rest_framework.response import Response
  13. from rest_framework.serializers import Field, ModelSerializer, ValidationError
  14. from rest_framework.viewsets import GenericViewSet, ViewSet
  15. WRITE_OPERATIONS = ['create', 'update', 'partial_update', 'delete']
  16. class ServiceUnavailable(APIException):
  17. status_code = 503
  18. default_detail = "Service temporarily unavailable, please try again later."
  19. #
  20. # Authentication
  21. #
  22. class IsAuthenticatedOrLoginNotRequired(BasePermission):
  23. """
  24. Returns True if the user is authenticated or LOGIN_REQUIRED is False.
  25. """
  26. def has_permission(self, request, view):
  27. if not settings.LOGIN_REQUIRED:
  28. return True
  29. return request.user.is_authenticated
  30. #
  31. # Fields
  32. #
  33. class ChoiceFieldSerializer(Field):
  34. """
  35. Represent a ChoiceField as {'value': <DB value>, 'label': <string>}.
  36. """
  37. def __init__(self, choices, **kwargs):
  38. self._choices = dict()
  39. for k, v in choices:
  40. # Unpack grouped choices
  41. if type(v) in [list, tuple]:
  42. for k2, v2 in v:
  43. self._choices[k2] = v2
  44. else:
  45. self._choices[k] = v
  46. super(ChoiceFieldSerializer, self).__init__(**kwargs)
  47. def to_representation(self, obj):
  48. return {'value': obj, 'label': self._choices[obj]}
  49. def to_internal_value(self, data):
  50. return data
  51. class ContentTypeFieldSerializer(Field):
  52. """
  53. Represent a ContentType as '<app_label>.<model>'
  54. """
  55. def to_representation(self, obj):
  56. return "{}.{}".format(obj.app_label, obj.model)
  57. def to_internal_value(self, data):
  58. app_label, model = data.split('.')
  59. try:
  60. return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
  61. except ContentType.DoesNotExist:
  62. raise ValidationError("Invalid content type")
  63. class TimeZoneField(Field):
  64. """
  65. Represent a pytz time zone.
  66. """
  67. def to_representation(self, obj):
  68. return obj.zone if obj else None
  69. def to_internal_value(self, data):
  70. if not data:
  71. return ""
  72. try:
  73. return pytz.timezone(str(data))
  74. except pytz.exceptions.UnknownTimeZoneError:
  75. raise ValidationError('Invalid time zone "{}"'.format(data))
  76. #
  77. # Serializers
  78. #
  79. class ValidatedModelSerializer(ModelSerializer):
  80. """
  81. Extends the built-in ModelSerializer to enforce calling clean() on the associated model during validation.
  82. """
  83. def validate(self, data):
  84. # Remove custom field data (if any) prior to model validation
  85. attrs = data.copy()
  86. attrs.pop('custom_fields', None)
  87. # Run clean() on an instance of the model
  88. if self.instance is None:
  89. model = self.Meta.model
  90. # Ignore ManyToManyFields for new instances (a PK is needed for validation)
  91. for field in model._meta.get_fields():
  92. if isinstance(field, ManyToManyField) and field.name in attrs:
  93. attrs.pop(field.name)
  94. instance = self.Meta.model(**attrs)
  95. else:
  96. instance = self.instance
  97. for k, v in attrs.items():
  98. setattr(instance, k, v)
  99. instance.clean()
  100. return data
  101. class WritableNestedSerializer(ModelSerializer):
  102. """
  103. Returns a nested representation of an object on read, but accepts only a primary key on write.
  104. """
  105. def to_internal_value(self, data):
  106. try:
  107. return self.Meta.model.objects.get(pk=data)
  108. except ObjectDoesNotExist:
  109. raise ValidationError("Invalid ID")
  110. #
  111. # Viewsets
  112. #
  113. class ModelViewSet(mixins.CreateModelMixin,
  114. mixins.RetrieveModelMixin,
  115. mixins.UpdateModelMixin,
  116. mixins.DestroyModelMixin,
  117. mixins.ListModelMixin,
  118. GenericViewSet):
  119. """
  120. Substitute DRF's built-in ModelViewSet for our own, which introduces a bit of additional functionality:
  121. 1. Use an alternate serializer (if provided) for write operations
  122. 2. Accept either a single object or a list of objects to create
  123. """
  124. def get_serializer_class(self):
  125. # Check for a different serializer to use for write operations
  126. if self.action in WRITE_OPERATIONS and hasattr(self, 'write_serializer_class'):
  127. return self.write_serializer_class
  128. return self.serializer_class
  129. def get_serializer(self, *args, **kwargs):
  130. # If a list of objects has been provided, initialize the serializer with many=True
  131. if isinstance(kwargs.get('data', {}), list):
  132. kwargs['many'] = True
  133. return super(ModelViewSet, self).get_serializer(*args, **kwargs)
  134. class FieldChoicesViewSet(ViewSet):
  135. """
  136. Expose the built-in numeric values which represent static choices for a model's field.
  137. """
  138. permission_classes = [IsAuthenticatedOrLoginNotRequired]
  139. fields = []
  140. def __init__(self, *args, **kwargs):
  141. super(FieldChoicesViewSet, self).__init__(*args, **kwargs)
  142. # Compile a dict of all fields in this view
  143. self._fields = OrderedDict()
  144. for cls, field_list in self.fields:
  145. for field_name in field_list:
  146. model_name = cls._meta.verbose_name.lower().replace(' ', '-')
  147. key = ':'.join([model_name, field_name])
  148. choices = []
  149. for k, v in cls._meta.get_field(field_name).choices:
  150. if type(v) in [list, tuple]:
  151. for k2, v2 in v:
  152. choices.append({
  153. 'value': k2,
  154. 'label': v2,
  155. })
  156. else:
  157. choices.append({
  158. 'value': k,
  159. 'label': v,
  160. })
  161. self._fields[key] = choices
  162. def list(self, request):
  163. return Response(self._fields)
  164. def retrieve(self, request, pk):
  165. if pk not in self._fields:
  166. raise Http404
  167. return Response(self._fields[pk])
  168. def get_view_name(self):
  169. return "Field Choices"