filters.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import django_filters
  2. from django import forms
  3. from django.conf import settings
  4. from django.db import models
  5. from dcim.forms import MACAddressField
  6. from extras.models import Tag
  7. def multivalue_field_factory(field_class):
  8. """
  9. Given a form field class, return a subclass capable of accepting multiple values. This allows us to OR on multiple
  10. filter values while maintaining the field's built-in validation. Example: GET /api/dcim/devices/?name=foo&name=bar
  11. """
  12. class NewField(field_class):
  13. widget = forms.SelectMultiple
  14. def to_python(self, value):
  15. if not value:
  16. return []
  17. return [
  18. # Only append non-empty values (this avoids e.g. trying to cast '' as an integer)
  19. super(field_class, self).to_python(v) for v in value if v
  20. ]
  21. return type('MultiValue{}'.format(field_class.__name__), (NewField,), dict())
  22. #
  23. # Filters
  24. #
  25. class MultiValueCharFilter(django_filters.MultipleChoiceFilter):
  26. field_class = multivalue_field_factory(forms.CharField)
  27. class MultiValueDateFilter(django_filters.MultipleChoiceFilter):
  28. field_class = multivalue_field_factory(forms.DateField)
  29. class MultiValueDateTimeFilter(django_filters.MultipleChoiceFilter):
  30. field_class = multivalue_field_factory(forms.DateTimeField)
  31. class MultiValueNumberFilter(django_filters.MultipleChoiceFilter):
  32. field_class = multivalue_field_factory(forms.IntegerField)
  33. class MultiValueTimeFilter(django_filters.MultipleChoiceFilter):
  34. field_class = multivalue_field_factory(forms.TimeField)
  35. class MACAddressFilter(django_filters.CharFilter):
  36. field_class = MACAddressField
  37. class MultiValueMACAddressFilter(django_filters.MultipleChoiceFilter):
  38. field_class = multivalue_field_factory(MACAddressField)
  39. class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
  40. """
  41. Filters for a set of Models, including all descendant models within a Tree. Example: [<Region: R1>,<Region: R2>]
  42. """
  43. def filter(self, qs, value):
  44. value = [node.get_descendants(include_self=True) for node in value]
  45. return super().filter(qs, value)
  46. class NumericInFilter(django_filters.BaseInFilter, django_filters.NumberFilter):
  47. """
  48. Filters for a set of numeric values. Example: id__in=100,200,300
  49. """
  50. pass
  51. class NullableCharFieldFilter(django_filters.CharFilter):
  52. """
  53. Allow matching on null field values by passing a special string used to signify NULL.
  54. """
  55. def filter(self, qs, value):
  56. if value != settings.FILTERS_NULL_CHOICE_VALUE:
  57. return super().filter(qs, value)
  58. qs = self.get_method(qs)(**{'{}__isnull'.format(self.field_name): True})
  59. return qs.distinct() if self.distinct else qs
  60. class TagFilter(django_filters.ModelMultipleChoiceFilter):
  61. """
  62. Match on one or more assigned tags. If multiple tags are specified (e.g. ?tag=foo&tag=bar), the queryset is filtered
  63. to objects matching all tags.
  64. """
  65. def __init__(self, *args, **kwargs):
  66. kwargs.setdefault('field_name', 'tags__slug')
  67. kwargs.setdefault('to_field_name', 'slug')
  68. kwargs.setdefault('conjoined', True)
  69. kwargs.setdefault('queryset', Tag.objects.all())
  70. super().__init__(*args, **kwargs)
  71. #
  72. # FilterSets
  73. #
  74. class NameSlugSearchFilterSet(django_filters.FilterSet):
  75. """
  76. A base class for adding the search method to models which only expose the `name` and `slug` fields
  77. """
  78. q = django_filters.CharFilter(
  79. method='search',
  80. label='Search',
  81. )
  82. def search(self, queryset, name, value):
  83. if not value.strip():
  84. return queryset
  85. return queryset.filter(
  86. models.Q(name__icontains=value) |
  87. models.Q(slug__icontains=value)
  88. )
  89. #
  90. # Update default filters
  91. #
  92. FILTER_DEFAULTS = django_filters.filterset.FILTER_FOR_DBFIELD_DEFAULTS
  93. FILTER_DEFAULTS.update({
  94. models.AutoField: {
  95. 'filter_class': MultiValueNumberFilter
  96. },
  97. models.CharField: {
  98. 'filter_class': MultiValueCharFilter
  99. },
  100. models.DateField: {
  101. 'filter_class': MultiValueDateFilter
  102. },
  103. models.DateTimeField: {
  104. 'filter_class': MultiValueDateTimeFilter
  105. },
  106. models.DecimalField: {
  107. 'filter_class': MultiValueNumberFilter
  108. },
  109. models.EmailField: {
  110. 'filter_class': MultiValueCharFilter
  111. },
  112. models.FloatField: {
  113. 'filter_class': MultiValueNumberFilter
  114. },
  115. models.IntegerField: {
  116. 'filter_class': MultiValueNumberFilter
  117. },
  118. models.PositiveIntegerField: {
  119. 'filter_class': MultiValueNumberFilter
  120. },
  121. models.PositiveSmallIntegerField: {
  122. 'filter_class': MultiValueNumberFilter
  123. },
  124. models.SlugField: {
  125. 'filter_class': MultiValueCharFilter
  126. },
  127. models.SmallIntegerField: {
  128. 'filter_class': MultiValueNumberFilter
  129. },
  130. models.TimeField: {
  131. 'filter_class': MultiValueTimeFilter
  132. },
  133. models.URLField: {
  134. 'filter_class': MultiValueCharFilter
  135. },
  136. })