Преглед изворни кода

fix(extras): Add choice_value lookup for ChoiceSetField (#22366)

Introduce ChoiceSetField as ArrayField subclass for custom field
choices and implement choice_value lookup to filter by value element
only. Update GraphQL filter to use ExtraChoicesLookup with contains and
length options.

Fixes #22324
Martin Hauser пре 1 месец
родитељ
комит
2e50fc3d97

+ 22 - 1
netbox/extras/fields.py

@@ -1,4 +1,10 @@
-from django.db.models import TextField
+from django.contrib.postgres.fields import ArrayField
+from django.db.models import CharField, TextField
+
+__all__ = (
+    'CachedValueField',
+    'ChoiceSetField',
+)
 
 
 class CachedValueField(TextField):
@@ -6,3 +12,18 @@ class CachedValueField(TextField):
     Currently a dummy field to prevent custom lookups being applied globally to TextField.
     """
     pass
+
+
+class ChoiceSetField(ArrayField):
+    """
+    An ArrayField of two-element [value, label] string pairs representing custom field choices.
+    """
+    def __init__(self, **kwargs):
+        kwargs['base_field'] = ArrayField(base_field=CharField(max_length=100), size=2)
+        super().__init__(**kwargs)
+
+    def deconstruct(self):
+        name, path, args, kwargs = super().deconstruct()
+        # base_field is fixed by __init__ and omitted from migrations
+        del kwargs['base_field']
+        return name, path, args, kwargs

+ 30 - 0
netbox/extras/graphql/filter_lookups.py

@@ -0,0 +1,30 @@
+import strawberry
+import strawberry_django
+from django.db.models import Q, QuerySet
+from strawberry.directive import DirectiveValue
+from strawberry.types import Info
+
+__all__ = (
+    'ExtraChoicesLookup',
+)
+
+
+@strawberry.input(
+    one_of=True,
+    description='Lookup for extra choices defined on a choice set. Only one of the lookup fields can be set.',
+)
+class ExtraChoicesLookup:
+    contains: str | None = strawberry.field(
+        default=strawberry.UNSET, description='Has an extra choice with this value'
+    )
+    length: int | None = strawberry.field(
+        default=strawberry.UNSET, description='Number of extra choices'
+    )
+
+    @strawberry_django.filter_field
+    def filter(self, info: Info, queryset: QuerySet, prefix: DirectiveValue[str] = '') -> tuple[QuerySet, Q]:
+        if self.contains is not strawberry.UNSET and self.contains is not None:
+            return queryset, Q(**{f'{prefix}choice_value': self.contains})
+        if self.length is not strawberry.UNSET and self.length is not None:
+            return queryset, Q(**{f'{prefix}len': self.length})
+        return queryset, Q()

+ 2 - 1
netbox/extras/graphql/filters.py

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
         SiteFilter,
         SiteGroupFilter,
     )
+    from extras.graphql.filter_lookups import ExtraChoicesLookup
     from netbox.graphql.enums import ColorEnum
     from netbox.graphql.filter_lookups import FloatLookup, IntegerLookup, JSONFilter, StringArrayLookup, TreeNodeFilter
     from tenancy.graphql.filters import TenantFilter, TenantGroupFilter
@@ -198,7 +199,7 @@ class CustomFieldChoiceSetFilter(ChangeLoggedModelFilter):
     ) = (
         strawberry_django.filter_field()
     )
-    extra_choices: Annotated['StringArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
+    extra_choices: Annotated['ExtraChoicesLookup', strawberry.lazy('extras.graphql.filter_lookups')] | None = (
         strawberry_django.filter_field()
     )
     order_alphabetically: FilterLookup[bool] | None = strawberry_django.filter_field()

+ 27 - 1
netbox/extras/lookups.py

@@ -3,7 +3,16 @@ from django.contrib.postgres.fields.ranges import RangeField
 from django.db.models import CharField, JSONField, Lookup
 from django.db.models.fields.json import KeyTextTransform
 
-from .fields import CachedValueField
+from .fields import CachedValueField, ChoiceSetField
+
+__all__ = (
+    'ChoiceValueLookup',
+    'Empty',
+    'JSONEmpty',
+    'NetContainsOrEquals',
+    'NetHost',
+    'RangeContains',
+)
 
 
 class RangeContains(Lookup):
@@ -34,6 +43,22 @@ class RangeContains(Lookup):
         return sql, params
 
 
+class ChoiceValueLookup(Lookup):
+    """
+    Match rows where any [value, label] pair in a ChoiceSetField has the given value.
+
+    Compares the RHS against the first element (the value) of each pair.
+    """
+    lookup_name = 'choice_value'
+    prepare_rhs = False
+
+    def as_sql(self, compiler, connection):
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+        rhs, rhs_params = self.process_rhs(compiler, connection)
+        # Slice the value column of the two-dimensional array and match any element
+        return f'{rhs} = ANY({lhs}[:][1:1])', [*rhs_params, *lhs_params]
+
+
 class Empty(Lookup):
     """
     Filter on whether a string is empty.
@@ -99,6 +124,7 @@ class NetContainsOrEquals(Lookup):
 
 
 ArrayField.register_lookup(RangeContains)
+ChoiceSetField.register_lookup(ChoiceValueLookup)
 CharField.register_lookup(Empty)
 JSONField.register_lookup(JSONEmpty)
 CachedValueField.register_lookup(NetHost)

+ 18 - 0
netbox/extras/migrations/0139_alter_customfieldchoiceset_extra_choices.py

@@ -0,0 +1,18 @@
+from django.db import migrations
+
+import extras.fields
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('extras', '0138_customfieldchoiceset_choice_colors'),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name='customfieldchoiceset',
+            name='extra_choices',
+            field=extras.fields.ChoiceSetField(blank=True, null=True),
+        ),
+    ]

+ 2 - 6
netbox/extras/models/customfields.py

@@ -7,7 +7,6 @@ import django_filters
 import jsonschema
 from django import forms
 from django.conf import settings
-from django.contrib.postgres.fields import ArrayField
 from django.core.validators import RegexValidator, ValidationError
 from django.db import models
 from django.db.models import F, Func, Value
@@ -21,6 +20,7 @@ from jsonschema.exceptions import ValidationError as JSONValidationError
 from core.models import ObjectType
 from extras.choices import *
 from extras.data import CHOICE_SETS
+from extras.fields import ChoiceSetField
 from netbox.context import query_cache
 from netbox.models import ChangeLoggedModel
 from netbox.models.features import CloningMixin, ExportTemplatesMixin
@@ -877,11 +877,7 @@ class CustomFieldChoiceSet(CloningMixin, ExportTemplatesMixin, OwnerMixin, Chang
         null=True,
         help_text=_('Base set of predefined choices (optional)')
     )
-    extra_choices = ArrayField(
-        ArrayField(
-            base_field=models.CharField(max_length=100),
-            size=2
-        ),
+    extra_choices = ChoiceSetField(
         blank=True,
         null=True
     )

+ 58 - 0
netbox/extras/tests/test_api.py

@@ -325,6 +325,64 @@ class CustomFieldChoiceSetTestCase(APIViewTestCases.APIViewTestCase):
         response = self.client.post(self._get_list_url(), data, format='json', **self.header)
         self.assertEqual(response.status_code, 400)
 
+    def test_graphql_filter_extra_choices(self):
+        """Filter choice sets by choice value and by number of choices."""
+        self.add_permissions('extras.view_customfieldchoiceset')
+
+        # '1A' appears here only as a label, so it must not match contains
+        CustomFieldChoiceSet.objects.create(
+            name='Choice Set Labels',
+            extra_choices=[['sel1', 'Selection 1'], ['other', '1A']],
+        )
+
+        def run(lookup):
+            query = '{ custom_field_choice_set_list(filters: {extra_choices: ' + lookup + '}) { name } }'
+            response = self.client.post(reverse('graphql'), data={'query': query}, format='json', **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            data = response.json()
+            self.assertNotIn('errors', data)
+            return sorted(row['name'] for row in data['data']['custom_field_choice_set_list'])
+
+        # contains matches choice values only, never labels
+        self.assertEqual(run('{contains: "1A"}'), ['Choice Set 1'])
+        self.assertEqual(run('{contains: "sel1"}'), ['Choice Set Labels'])
+        self.assertEqual(run('{contains: "Selection 1"}'), [])
+        # length is the number of [value, label] pairs
+        self.assertEqual(run('{length: 2}'), ['Choice Set Labels'])
+        self.assertEqual(run('{length: 1}'), [])
+
+    def test_graphql_filter_extra_choices_rejects_array_operands(self):
+        """The legacy flat and nested array operand shapes fail schema validation."""
+        self.add_permissions('extras.view_customfieldchoiceset')
+
+        def run_invalid(lookup):
+            query = '{ custom_field_choice_set_list(filters: {extra_choices: ' + lookup + '}) { name } }'
+            response = self.client.post(reverse('graphql'), data={'query': query}, format='json', **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            self.assertIn('errors', response.json())
+
+        # shapes advertised or attempted before #22324
+        run_invalid('{contains: ["1A"]}')
+        run_invalid('{contains: [["1A", "Choice 1A"]]}')
+
+    def test_graphql_filter_extra_choices_via_relation(self):
+        """The extra_choices lookup composes through the choice_set relation prefix."""
+        self.add_permissions('extras.view_customfield')
+
+        for choice_set in CustomFieldChoiceSet.objects.filter(name__in=['Choice Set 1', 'Choice Set 2']):
+            CustomField.objects.create(
+                name=f'cf_{choice_set.name[-1]}',
+                type=CustomFieldTypeChoices.TYPE_SELECT,
+                choice_set=choice_set,
+            )
+
+        query = '{ custom_field_list(filters: {choice_set: {extra_choices: {contains: "1A"}}}) { name } }'
+        response = self.client.post(reverse('graphql'), data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+        self.assertEqual([row['name'] for row in data['data']['custom_field_list']], ['cf_1'])
+
 
 class CustomLinkTestCase(APIViewTestCases.APIViewTestCase):
     model = CustomLink

+ 31 - 0
netbox/extras/tests/test_lookups.py

@@ -0,0 +1,31 @@
+from django.core.exceptions import FieldError
+from django.test import TestCase
+
+from extras.choices import CustomFieldChoiceSetBaseChoices
+from extras.models import CustomFieldChoiceSet, EventRule
+
+
+class ChoiceValueLookupTestCase(TestCase):
+
+    def test_choice_value_matches_values_only(self):
+        """choice_value matches the value element of a pair, never the label."""
+        CustomFieldChoiceSet.objects.create(
+            name='Choice Set 1',
+            extra_choices=[['sel1', 'Selection 1'], ['other', 'sel2']],
+        )
+        self.assertEqual(CustomFieldChoiceSet.objects.filter(extra_choices__choice_value='sel1').count(), 1)
+        self.assertEqual(CustomFieldChoiceSet.objects.filter(extra_choices__choice_value='sel2').count(), 0)
+
+    def test_choice_value_excludes_null_extra_choices(self):
+        """Choice sets without extra choices are excluded without raising."""
+        CustomFieldChoiceSet.objects.create(
+            name='Base Only',
+            base_choices=CustomFieldChoiceSetBaseChoices.IATA,
+        )
+        self.assertEqual(CustomFieldChoiceSet.objects.filter(extra_choices__choice_value='sel1').count(), 0)
+        self.assertEqual(CustomFieldChoiceSet.objects.filter(extra_choices__len=2).count(), 0)
+
+    def test_choice_value_not_registered_on_plain_array_fields(self):
+        """choice_value is scoped to ChoiceSetField and unavailable on other ArrayFields."""
+        with self.assertRaises(FieldError):
+            EventRule.objects.filter(event_types__choice_value='x').exists()

+ 8 - 3
netbox/utilities/serializers/json.py

@@ -1,10 +1,15 @@
 from django.contrib.postgres.fields import ArrayField
-from django.core.serializers.json import Deserializer  # noqa: F401
+from django.core.serializers.json import Deserializer
 from django.core.serializers.json import Serializer as Serializer_
 from django.utils.encoding import is_protected_type
 
 # NOTE: Module must contain both Serializer and Deserializer
 
+__all__ = (
+    'Deserializer',
+    'Serializer',
+)
+
 
 class Serializer(Serializer_):
     """
@@ -14,8 +19,8 @@ class Serializer(Serializer_):
     def _value_from_field(self, obj, field):
         value = field.value_from_object(obj)
 
-        # Handle ArrayFields of protected types
-        if type(field) is ArrayField:
+        # Handle ArrayFields (including subclasses) of protected types
+        if isinstance(field, ArrayField):
             if not value or is_protected_type(value[0]):
                 return value
 

+ 2 - 2
netbox/utilities/testing/base.py

@@ -198,8 +198,8 @@ class ModelTestCase(TestCase):
                     model_dict[key] = [[r.lower, r.upper - 1] for r in value]
 
             else:
-                # Convert ArrayFields to CSV strings
-                if type(field) is ArrayField:
+                # Convert ArrayFields (including subclasses) to CSV strings
+                if isinstance(field, ArrayField):
                     if getattr(field.base_field, 'choices', None):
                         # Values for fields with pre-defined choices can be returned as lists
                         model_dict[key] = value

+ 12 - 1
netbox/utilities/tests/test_serialization.py

@@ -2,7 +2,8 @@ from django.test import TestCase
 
 from dcim.choices import SiteStatusChoices
 from dcim.models import Site
-from extras.models import Tag
+from extras.choices import CustomFieldChoiceSetBaseChoices
+from extras.models import CustomFieldChoiceSet, Tag
 from utilities.serialization import deserialize_object, serialize_object
 
 
@@ -32,6 +33,16 @@ class SerializationTestCase(TestCase):
         self.assertEqual(data['foo'], 123)
         self.assertNotIn('description', data)
 
+    def test_serialize_object_empty_array_field_subclass(self):
+        """An empty ArrayField subclass value serializes as a list, not a string."""
+        choice_set = CustomFieldChoiceSet.objects.create(
+            name='Choice Set 1',
+            base_choices=CustomFieldChoiceSetBaseChoices.IATA,
+            extra_choices=[],
+        )
+        data = serialize_object(choice_set)
+        self.assertEqual(data['extra_choices'], [])
+
     def test_deserialize_object(self):
         data = {
             'name': 'Site 1',