api.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from __future__ import unicode_literals
  2. from django.conf import settings
  3. from django.contrib.contenttypes.models import ContentType
  4. from rest_framework import authentication, exceptions
  5. from rest_framework.exceptions import APIException
  6. from rest_framework.pagination import LimitOffsetPagination
  7. from rest_framework.permissions import DjangoModelPermissions, SAFE_METHODS
  8. from rest_framework.serializers import Field, ValidationError
  9. from users.models import Token
  10. WRITE_OPERATIONS = ['create', 'update', 'partial_update', 'delete']
  11. class ServiceUnavailable(APIException):
  12. status_code = 503
  13. default_detail = "Service temporarily unavailable, please try again later."
  14. class TokenAuthentication(authentication.TokenAuthentication):
  15. """
  16. A custom authentication scheme which enforces Token expiration times.
  17. """
  18. model = Token
  19. def authenticate_credentials(self, key):
  20. model = self.get_model()
  21. try:
  22. token = model.objects.select_related('user').get(key=key)
  23. except model.DoesNotExist:
  24. raise exceptions.AuthenticationFailed("Invalid token")
  25. # Enforce the Token's expiration time, if one has been set.
  26. if token.is_expired:
  27. raise exceptions.AuthenticationFailed("Token expired")
  28. if not token.user.is_active:
  29. raise exceptions.AuthenticationFailed("User inactive")
  30. return token.user, token
  31. class TokenPermissions(DjangoModelPermissions):
  32. """
  33. Custom permissions handler which extends the built-in DjangoModelPermissions to validate a Token's write ability
  34. for unsafe requests (POST/PUT/PATCH/DELETE).
  35. """
  36. def __init__(self):
  37. # LOGIN_REQUIRED determines whether read-only access is provided to anonymous users.
  38. self.authenticated_users_only = settings.LOGIN_REQUIRED
  39. super(TokenPermissions, self).__init__()
  40. def has_permission(self, request, view):
  41. # If token authentication is in use, verify that the token allows write operations (for unsafe methods).
  42. if request.method not in SAFE_METHODS and isinstance(request.auth, Token):
  43. if not request.auth.write_enabled:
  44. return False
  45. return super(TokenPermissions, self).has_permission(request, view)
  46. class ChoiceFieldSerializer(Field):
  47. """
  48. Represent a ChoiceField as {'value': <DB value>, 'label': <string>}.
  49. """
  50. def __init__(self, choices, **kwargs):
  51. self._choices = dict()
  52. for k, v in choices:
  53. # Unpack grouped choices
  54. if type(v) in [list, tuple]:
  55. for k2, v2 in v:
  56. self._choices[k2] = v2
  57. else:
  58. self._choices[k] = v
  59. super(ChoiceFieldSerializer, self).__init__(**kwargs)
  60. def to_representation(self, obj):
  61. return {'value': obj, 'label': self._choices[obj]}
  62. def to_internal_value(self, data):
  63. return self._choices.get(data)
  64. class ContentTypeFieldSerializer(Field):
  65. """
  66. Represent a ContentType as '<app_label>.<model>'
  67. """
  68. def to_representation(self, obj):
  69. return "{}.{}".format(obj.app_label, obj.model)
  70. def to_internal_value(self, data):
  71. app_label, model = data.split('.')
  72. try:
  73. return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
  74. except ContentType.DoesNotExist:
  75. raise ValidationError("Invalid content type")
  76. class ModelValidationMixin(object):
  77. """
  78. Enforce a model's validation through clean() when validating serializer data. This is necessary to ensure we're
  79. employing the same validation logic via both forms and the API.
  80. """
  81. def validate(self, attrs):
  82. instance = self.Meta.model(**attrs)
  83. instance.clean()
  84. return attrs
  85. class WritableSerializerMixin(object):
  86. """
  87. Allow for the use of an alternate, writable serializer class for write operations (e.g. POST, PUT).
  88. """
  89. def get_serializer_class(self):
  90. if self.action in WRITE_OPERATIONS and hasattr(self, 'write_serializer_class'):
  91. return self.write_serializer_class
  92. return self.serializer_class
  93. class OptionalLimitOffsetPagination(LimitOffsetPagination):
  94. """
  95. Override the stock paginator to allow setting limit=0 to disable pagination for a request. This returns all objects
  96. matching a query, but retains the same format as a paginated request. The limit can only be disabled if
  97. MAX_PAGE_SIZE has been set to 0 or None.
  98. """
  99. def paginate_queryset(self, queryset, request, view=None):
  100. try:
  101. self.count = queryset.count()
  102. except (AttributeError, TypeError):
  103. self.count = len(queryset)
  104. self.limit = self.get_limit(request)
  105. self.offset = self.get_offset(request)
  106. self.request = request
  107. if self.limit and self.count > self.limit and self.template is not None:
  108. self.display_page_controls = True
  109. if self.count == 0 or self.offset > self.count:
  110. return list()
  111. if self.limit:
  112. return list(queryset[self.offset:self.offset + self.limit])
  113. else:
  114. return list(queryset[self.offset:])
  115. def get_limit(self, request):
  116. if self.limit_query_param:
  117. try:
  118. limit = int(request.query_params[self.limit_query_param])
  119. if limit < 0:
  120. raise ValueError()
  121. # Enforce maximum page size, if defined
  122. if settings.MAX_PAGE_SIZE:
  123. if limit == 0:
  124. return settings.MAX_PAGE_SIZE
  125. else:
  126. return min(limit, settings.MAX_PAGE_SIZE)
  127. return limit
  128. except (KeyError, ValueError):
  129. pass
  130. return self.default_limit