Kaynağa Gözat

Change attr_type from list to str for MultipleChoiceFilter (#17638)

Jeremy Stretch 1 yıl önce
ebeveyn
işleme
f11dc00fae

+ 1 - 1
netbox/dcim/filtersets.py

@@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM
 
     class Meta:
         model = Location
-        fields = ('id', 'name', 'slug', 'status', 'facility', 'description')
+        fields = ('id', 'name', 'slug', 'facility', 'description')
 
     def search(self, queryset, name, value):
         if not value.strip():

+ 4 - 4
netbox/netbox/graphql/filter_mixins.py

@@ -1,11 +1,12 @@
-from functools import partial, partialmethod, wraps
+from functools import partialmethod
 from typing import List
 
 import django_filters
 import strawberry
 import strawberry_django
-from django.core.exceptions import FieldDoesNotExist, ValidationError
+from django.core.exceptions import FieldDoesNotExist
 from strawberry import auto
+
 from ipam.fields import ASNField
 from netbox.graphql.scalars import BigInt
 from utilities.fields import ColorField, CounterCacheField
@@ -108,8 +109,7 @@ def map_strawberry_type(field):
     elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
         pass
     elif issubclass(type(field), django_filters.MultipleChoiceFilter):
-        should_create_function = True
-        attr_type = List[str] | None
+        attr_type = str | None
     elif issubclass(type(field), django_filters.TypedChoiceFilter):
         pass
     elif issubclass(type(field), django_filters.ChoiceFilter):

+ 37 - 4
netbox/netbox/tests/test_graphql.py

@@ -5,8 +5,8 @@ from django.urls import reverse
 from rest_framework import status
 
 from core.models import ObjectType
+from dcim.choices import LocationStatusChoices
 from dcim.models import Site, Location
-from ipam.models import ASN, RIR
 from users.models import ObjectPermission
 from utilities.testing import disable_warnings, APITestCase, TestCase
 
@@ -53,10 +53,27 @@ class GraphQLAPITestCase(APITestCase):
         sites = (
             Site(name='Site 1', slug='site-1'),
             Site(name='Site 2', slug='site-2'),
+            Site(name='Site 3', slug='site-3'),
         )
         Site.objects.bulk_create(sites)
-        Location.objects.create(site=sites[0], name='Location 1', slug='location-1'),
-        Location.objects.create(site=sites[1], name='Location 2', slug='location-2'),
+        Location.objects.create(
+            site=sites[0],
+            name='Location 1',
+            slug='location-1',
+            status=LocationStatusChoices.STATUS_PLANNED
+        ),
+        Location.objects.create(
+            site=sites[1],
+            name='Location 2',
+            slug='location-2',
+            status=LocationStatusChoices.STATUS_STAGING
+        ),
+        Location.objects.create(
+            site=sites[1],
+            name='Location 3',
+            slug='location-3',
+            status=LocationStatusChoices.STATUS_ACTIVE
+        ),
 
         # Add object-level permission
         obj_perm = ObjectPermission(
@@ -68,8 +85,9 @@ class GraphQLAPITestCase(APITestCase):
         obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
         obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))
 
-        # A valid request should return the filtered list
         url = reverse('graphql')
+
+        # A valid request should return the filtered list
         query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
         response = self.client.post(url, data={'query': query}, format="json", **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
@@ -78,6 +96,21 @@ class GraphQLAPITestCase(APITestCase):
         self.assertEqual(len(data['data']['location_list']), 1)
         self.assertIsNotNone(data['data']['location_list'][0]['site'])
 
+        # Test OR logic
+        query = """{
+            location_list( filters: {
+                status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\",
+                OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"}
+            }) {
+                id site {id}
+            }
+        }"""
+        response = self.client.post(url, data={'query': query}, format="json", **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = json.loads(response.content)
+        self.assertNotIn('errors', data)
+        self.assertEqual(len(data['data']['location_list']), 2)
+
         # An invalid request should return an empty list
         query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}'  # Invalid site ID
         response = self.client.post(url, data={'query': query}, format="json", **self.header)