filters.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import django_filters
  2. from django import forms
  3. from django.conf import settings
  4. from django.core.exceptions import ValidationError
  5. from django_filters.constants import EMPTY_VALUES
  6. from drf_spectacular.utils import extend_schema_field
  7. from drf_spectacular.types import OpenApiTypes
  8. __all__ = (
  9. 'ContentTypeFilter',
  10. 'MACAddressFilter',
  11. 'MultiValueCharFilter',
  12. 'MultiValueDateFilter',
  13. 'MultiValueDateTimeFilter',
  14. 'MultiValueDecimalFilter',
  15. 'MultiValueMACAddressFilter',
  16. 'MultiValueNumberFilter',
  17. 'MultiValueTimeFilter',
  18. 'MultiValueWWNFilter',
  19. 'NullableCharFieldFilter',
  20. 'NumericArrayFilter',
  21. 'TreeNodeMultipleChoiceFilter',
  22. )
  23. def multivalue_field_factory(field_class):
  24. """
  25. Given a form field class, return a subclass capable of accepting multiple values. This allows us to OR on multiple
  26. filter values while maintaining the field's built-in validation. Example: GET /api/dcim/devices/?name=foo&name=bar
  27. """
  28. class NewField(field_class):
  29. widget = forms.SelectMultiple
  30. def to_python(self, value):
  31. if not value:
  32. return []
  33. field = field_class()
  34. return [
  35. # Only append non-empty values (this avoids e.g. trying to cast '' as an integer)
  36. field.to_python(v) for v in value if v
  37. ]
  38. def run_validators(self, value):
  39. for v in value:
  40. super().run_validators(v)
  41. def validate(self, value):
  42. for v in value:
  43. super().validate(v)
  44. return type(f'MultiValue{field_class.__name__}', (NewField,), dict())
  45. #
  46. # Filters
  47. #
  48. @extend_schema_field(OpenApiTypes.STR)
  49. class MultiValueCharFilter(django_filters.MultipleChoiceFilter):
  50. field_class = multivalue_field_factory(forms.CharField)
  51. @extend_schema_field(OpenApiTypes.DATE)
  52. class MultiValueDateFilter(django_filters.MultipleChoiceFilter):
  53. field_class = multivalue_field_factory(forms.DateField)
  54. @extend_schema_field(OpenApiTypes.DATETIME)
  55. class MultiValueDateTimeFilter(django_filters.MultipleChoiceFilter):
  56. field_class = multivalue_field_factory(forms.DateTimeField)
  57. @extend_schema_field(OpenApiTypes.INT32)
  58. class MultiValueNumberFilter(django_filters.MultipleChoiceFilter):
  59. field_class = multivalue_field_factory(forms.IntegerField)
  60. @extend_schema_field(OpenApiTypes.DECIMAL)
  61. class MultiValueDecimalFilter(django_filters.MultipleChoiceFilter):
  62. field_class = multivalue_field_factory(forms.DecimalField)
  63. @extend_schema_field(OpenApiTypes.TIME)
  64. class MultiValueTimeFilter(django_filters.MultipleChoiceFilter):
  65. field_class = multivalue_field_factory(forms.TimeField)
  66. class MACAddressFilter(django_filters.CharFilter):
  67. pass
  68. @extend_schema_field(OpenApiTypes.STR)
  69. class MultiValueMACAddressFilter(django_filters.MultipleChoiceFilter):
  70. field_class = multivalue_field_factory(forms.CharField)
  71. def filter(self, qs, value):
  72. try:
  73. return super().filter(qs, value)
  74. except ValidationError:
  75. return qs.none()
  76. @extend_schema_field(OpenApiTypes.STR)
  77. class MultiValueWWNFilter(django_filters.MultipleChoiceFilter):
  78. field_class = multivalue_field_factory(forms.CharField)
  79. class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
  80. """
  81. Filters for a set of Models, including all descendant models within a Tree. Example: [<Region: R1>,<Region: R2>]
  82. """
  83. def get_filter_predicate(self, v):
  84. # Null value filtering
  85. if v is None:
  86. return {f"{self.field_name}__isnull": True}
  87. return super().get_filter_predicate(v)
  88. def filter(self, qs, value):
  89. value = [node.get_descendants(include_self=True) if not isinstance(node, str) else node for node in value]
  90. return super().filter(qs, value)
  91. class NullableCharFieldFilter(django_filters.CharFilter):
  92. """
  93. Allow matching on null field values by passing a special string used to signify NULL.
  94. """
  95. def filter(self, qs, value):
  96. if value != settings.FILTERS_NULL_CHOICE_VALUE:
  97. return super().filter(qs, value)
  98. qs = self.get_method(qs)(**{'{}__isnull'.format(self.field_name): True})
  99. return qs.distinct() if self.distinct else qs
  100. class NumericArrayFilter(django_filters.NumberFilter):
  101. """
  102. Filter based on the presence of an integer within an ArrayField.
  103. """
  104. def filter(self, qs, value):
  105. if value:
  106. value = [value]
  107. return super().filter(qs, value)
  108. class ContentTypeFilter(django_filters.CharFilter):
  109. """
  110. Allow specifying a ContentType by <app_label>.<model> (e.g. "dcim.site").
  111. """
  112. def filter(self, qs, value):
  113. if value in EMPTY_VALUES:
  114. return qs
  115. try:
  116. app_label, model = value.lower().split('.')
  117. except ValueError:
  118. return qs.none()
  119. return qs.filter(
  120. **{
  121. f'{self.field_name}__app_label': app_label,
  122. f'{self.field_name}__model': model
  123. }
  124. )