Arthur 1 год назад
Родитель
Сommit
b75b9e01eb

+ 4 - 0
base_requirements.txt

@@ -131,6 +131,10 @@ social-auth-core
 # https://github.com/python-social-auth/social-app-django/blob/master/CHANGELOG.md
 social-auth-app-django
 
+# Strawberry GraphQL
+# https://github.com/strawberry-graphql/strawberry/blob/main/CHANGELOG.md
+strawberry-graphql
+
 # Strawberry GraphQL Django extension
 # https://github.com/strawberry-graphql/strawberry-django/blob/main/CHANGELOG.md
 strawberry-django

+ 6 - 6
netbox/circuits/graphql/schema.py

@@ -11,30 +11,30 @@ from .types import *
 class CircuitsQuery:
     @strawberry.field
     def circuit(self, id: int) -> CircuitType:
-        return models.Circuit.objects.get(id=id)
+        return models.Circuit.objects.get(pk=id)
     circuit_list: List[CircuitType] = strawberry_django.field()
 
     @strawberry.field
     def circuit_termination(self, id: int) -> CircuitTerminationType:
-        return models.CircuitTermination.objects.get(id=id)
+        return models.CircuitTermination.objects.get(pk=id)
     circuit_termination_list: List[CircuitTerminationType] = strawberry_django.field()
 
     @strawberry.field
     def circuit_type(self, id: int) -> CircuitTypeType:
-        return models.CircuitType.objects.get(id=id)
+        return models.CircuitType.objects.get(pk=id)
     circuit_type_list: List[CircuitTypeType] = strawberry_django.field()
 
     @strawberry.field
     def provider(self, id: int) -> ProviderType:
-        return models.Provider.objects.get(id=id)
+        return models.Provider.objects.get(pk=id)
     provider_list: List[ProviderType] = strawberry_django.field()
 
     @strawberry.field
     def provider_account(self, id: int) -> ProviderAccountType:
-        return models.ProviderAccount.objects.get(id=id)
+        return models.ProviderAccount.objects.get(pk=id)
     provider_account_list: List[ProviderAccountType] = strawberry_django.field()
 
     @strawberry.field
     def provider_network(self, id: int) -> ProviderNetworkType:
-        return models.ProviderNetwork.objects.get(id=id)
+        return models.ProviderNetwork.objects.get(pk=id)
     provider_network_list: List[ProviderNetworkType] = strawberry_django.field()

+ 2 - 2
netbox/core/graphql/schema.py

@@ -11,10 +11,10 @@ from .types import *
 class CoreQuery:
     @strawberry.field
     def data_file(self, id: int) -> DataFileType:
-        return models.DataFile.objects.get(id=id)
+        return models.DataFile.objects.get(pk=id)
     data_file_list: List[DataFileType] = strawberry_django.field()
 
     @strawberry.field
     def data_source(self, id: int) -> DataSourceType:
-        return models.DataSource.objects.get(id=id)
+        return models.DataSource.objects.get(pk=id)
     data_source_list: List[DataSourceType] = strawberry_django.field()

+ 1 - 1
netbox/core/graphql/types.py

@@ -15,7 +15,7 @@ __all__ = (
 
 @strawberry_django.type(
     models.DataFile,
-    exclude=('data',),
+    exclude=['data',],
     filters=DataFileFilter
 )
 class DataFileType(BaseObjectType):

+ 27 - 27
netbox/dcim/graphql/schema.py

@@ -11,137 +11,137 @@ from .types import *
 class DCIMQuery:
     @strawberry.field
     def cable(self, id: int) -> CableType:
-        return models.Cable.objects.get(id=id)
+        return models.Cable.objects.get(pk=id)
     cable_list: List[CableType] = strawberry_django.field()
 
     @strawberry.field
     def console_port(self, id: int) -> ConsolePortType:
-        return models.ConsolePort.objects.get(id=id)
+        return models.ConsolePort.objects.get(pk=id)
     console_port_list: List[ConsolePortType] = strawberry_django.field()
 
     @strawberry.field
     def console_port_template(self, id: int) -> ConsolePortTemplateType:
-        return models.ConsolePortTemplate.objects.get(id=id)
+        return models.ConsolePortTemplate.objects.get(pk=id)
     console_port_template_list: List[ConsolePortTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def console_server_port(self, id: int) -> ConsoleServerPortType:
-        return models.ConsoleServerPort.objects.get(id=id)
+        return models.ConsoleServerPort.objects.get(pk=id)
     console_server_port_list: List[ConsoleServerPortType] = strawberry_django.field()
 
     @strawberry.field
     def console_server_port_template(self, id: int) -> ConsoleServerPortTemplateType:
-        return models.ConsoleServerPortTemplate.objects.get(id=id)
+        return models.ConsoleServerPortTemplate.objects.get(pk=id)
     console_server_port_template_list: List[ConsoleServerPortTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def device(self, id: int) -> DeviceType:
-        return models.Device.objects.get(id=id)
+        return models.Device.objects.get(pk=id)
     device_list: List[DeviceType] = strawberry_django.field()
 
     @strawberry.field
     def device_bay(self, id: int) -> DeviceBayType:
-        return models.DeviceBay.objects.get(id=id)
+        return models.DeviceBay.objects.get(pk=id)
     device_bay_list: List[DeviceBayType] = strawberry_django.field()
 
     @strawberry.field
     def device_bay_template(self, id: int) -> DeviceBayTemplateType:
-        return models.DeviceBayTemplate.objects.get(id=id)
+        return models.DeviceBayTemplate.objects.get(pk=id)
     device_bay_template_list: List[DeviceBayTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def device_role(self, id: int) -> DeviceRoleType:
-        return models.DeviceRole.objects.get(id=id)
+        return models.DeviceRole.objects.get(pk=id)
     device_role_list: List[DeviceRoleType] = strawberry_django.field()
 
     @strawberry.field
     def device_type(self, id: int) -> DeviceTypeType:
-        return models.DeviceType.objects.get(id=id)
+        return models.DeviceType.objects.get(pk=id)
     device_type_list: List[DeviceTypeType] = strawberry_django.field()
 
     @strawberry.field
     def front_port(self, id: int) -> FrontPortType:
-        return models.FrontPort.objects.get(id=id)
+        return models.FrontPort.objects.get(pk=id)
     front_port_list: List[FrontPortType] = strawberry_django.field()
 
     @strawberry.field
     def front_port_template(self, id: int) -> FrontPortTemplateType:
-        return models.FrontPortTemplate.objects.get(id=id)
+        return models.FrontPortTemplate.objects.get(pk=id)
     front_port_template_list: List[FrontPortTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def interface(self, id: int) -> InterfaceType:
-        return models.Interface.objects.get(id=id)
+        return models.Interface.objects.get(pk=id)
     interface_list: List[InterfaceType] = strawberry_django.field()
 
     @strawberry.field
     def interface_template(self, id: int) -> InterfaceTemplateType:
-        return models.InterfaceTemplate.objects.get(id=id)
+        return models.InterfaceTemplate.objects.get(pk=id)
     interface_template_list: List[InterfaceTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def inventory_item(self, id: int) -> InventoryItemType:
-        return models.InventoryItem.objects.get(id=id)
+        return models.InventoryItem.objects.get(pk=id)
     inventory_item_list: List[InventoryItemType] = strawberry_django.field()
 
     @strawberry.field
     def inventory_item_role(self, id: int) -> InventoryItemRoleType:
-        return models.InventoryItemRole.objects.get(id=id)
+        return models.InventoryItemRole.objects.get(pk=id)
     inventory_item_role_list: List[InventoryItemRoleType] = strawberry_django.field()
 
     @strawberry.field
     def inventory_item_template(self, id: int) -> InventoryItemTemplateType:
-        return models.InventoryItemTemplate.objects.get(id=id)
+        return models.InventoryItemTemplate.objects.get(pk=id)
     inventory_item_template_list: List[InventoryItemTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def location(self, id: int) -> LocationType:
-        return models.Location.objects.get(id=id)
+        return models.Location.objects.get(pk=id)
     location_list: List[LocationType] = strawberry_django.field()
 
     @strawberry.field
     def manufacturer(self, id: int) -> ManufacturerType:
-        return models.Manufacturer.objects.get(id=id)
+        return models.Manufacturer.objects.get(pk=id)
     manufacturer_list: List[ManufacturerType] = strawberry_django.field()
 
     @strawberry.field
     def module(self, id: int) -> ModuleType:
-        return models.Module.objects.get(id=id)
+        return models.Module.objects.get(pk=id)
     module_list: List[ModuleType] = strawberry_django.field()
 
     @strawberry.field
     def module_bay(self, id: int) -> ModuleBayType:
-        return models.ModuleBay.objects.get(id=id)
+        return models.ModuleBay.objects.get(pk=id)
     module_bay_list: List[ModuleBayType] = strawberry_django.field()
 
     @strawberry.field
     def module_bay_template(self, id: int) -> ModuleBayTemplateType:
-        return models.ModuleBayTemplate.objects.get(id=id)
+        return models.ModuleBayTemplate.objects.get(pk=id)
     module_bay_template_list: List[ModuleBayTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def module_type(self, id: int) -> ModuleTypeType:
-        return models.ModuleType.objects.get(id=id)
+        return models.ModuleType.objects.get(pk=id)
     module_type_list: List[ModuleTypeType] = strawberry_django.field()
 
     @strawberry.field
     def platform(self, id: int) -> PlatformType:
-        return models.Platform.objects.get(id=id)
+        return models.Platform.objects.get(pk=id)
     platform_list: List[PlatformType] = strawberry_django.field()
 
     @strawberry.field
     def power_feed(self, id: int) -> PowerFeedType:
-        return models.PowerFeed.objects.get(id=id)
+        return models.PowerFeed.objects.get(pk=id)
     power_feed_list: List[PowerFeedType] = strawberry_django.field()
 
     @strawberry.field
     def power_outlet(self, id: int) -> PowerOutletType:
-        return models.PowerOutlet.objects.get(id=id)
+        return models.PowerOutlet.objects.get(pk=id)
     power_outlet_list: List[PowerOutletType] = strawberry_django.field()
 
     @strawberry.field
     def power_outlet_template(self, id: int) -> PowerOutletTemplateType:
-        return models.PowerOutletTemplate.objects.get(id=id)
+        return models.PowerOutletTemplate.objects.get(pk=id)
     power_outlet_template_list: List[PowerOutletTemplateType] = strawberry_django.field()
 
     @strawberry.field

+ 12 - 12
netbox/extras/graphql/schema.py

@@ -11,60 +11,60 @@ from .types import *
 class ExtrasQuery:
     @strawberry.field
     def config_context(self, id: int) -> ConfigContextType:
-        return models.ConfigContext.objects.get(id=id)
+        return models.ConfigContext.objects.get(pk=id)
     config_context_list: List[ConfigContextType] = strawberry_django.field()
 
     @strawberry.field
     def config_template(self, id: int) -> ConfigTemplateType:
-        return models.ConfigTemplate.objects.get(id=id)
+        return models.ConfigTemplate.objects.get(pk=id)
     config_template_list: List[ConfigTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def custom_field(self, id: int) -> CustomFieldType:
-        return models.CustomField.objects.get(id=id)
+        return models.CustomField.objects.get(pk=id)
     custom_field_list: List[CustomFieldType] = strawberry_django.field()
 
     @strawberry.field
     def custom_field_choice_set(self, id: int) -> CustomFieldChoiceSetType:
-        return models.CustomFieldChoiceSet.objects.get(id=id)
+        return models.CustomFieldChoiceSet.objects.get(pk=id)
     custom_field_choice_set_list: List[CustomFieldChoiceSetType] = strawberry_django.field()
 
     @strawberry.field
     def custom_link(self, id: int) -> CustomLinkType:
-        return models.CustomLink.objects.get(id=id)
+        return models.CustomLink.objects.get(pk=id)
     custom_link_list: List[CustomLinkType] = strawberry_django.field()
 
     @strawberry.field
     def export_template(self, id: int) -> ExportTemplateType:
-        return models.ExportTemplate.objects.get(id=id)
+        return models.ExportTemplate.objects.get(pk=id)
     export_template_list: List[ExportTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def image_attachment(self, id: int) -> ImageAttachmentType:
-        return models.ImageAttachment.objects.get(id=id)
+        return models.ImageAttachment.objects.get(pk=id)
     image_attachment_list: List[ImageAttachmentType] = strawberry_django.field()
 
     @strawberry.field
     def saved_filter(self, id: int) -> SavedFilterType:
-        return models.SavedFilter.objects.get(id=id)
+        return models.SavedFilter.objects.get(pk=id)
     saved_filter_list: List[SavedFilterType] = strawberry_django.field()
 
     @strawberry.field
     def journal_entry(self, id: int) -> JournalEntryType:
-        return models.JournalEntry.objects.get(id=id)
+        return models.JournalEntry.objects.get(pk=id)
     journal_entry_list: List[JournalEntryType] = strawberry_django.field()
 
     @strawberry.field
     def tag(self, id: int) -> TagType:
-        return models.Tag.objects.get(id=id)
+        return models.Tag.objects.get(pk=id)
     tag_list: List[TagType] = strawberry_django.field()
 
     @strawberry.field
     def webhook(self, id: int) -> WebhookType:
-        return models.Webhook.objects.get(id=id)
+        return models.Webhook.objects.get(pk=id)
     webhook_list: List[WebhookType] = strawberry_django.field()
 
     @strawberry.field
     def event_rule(self, id: int) -> EventRuleType:
-        return models.EventRule.objects.get(id=id)
+        return models.EventRule.objects.get(pk=id)
     event_rule_list: List[EventRuleType] = strawberry_django.field()

+ 16 - 16
netbox/ipam/graphql/schema.py

@@ -11,80 +11,80 @@ from .types import *
 class IPAMQuery:
     @strawberry.field
     def asn(self, id: int) -> ASNType:
-        return models.ASN.objects.get(id=id)
+        return models.ASN.objects.get(pk=id)
     asn_list: List[ASNType] = strawberry_django.field()
 
     @strawberry.field
     def asn_range(self, id: int) -> ASNRangeType:
-        return models.ASNRange.objects.get(id=id)
+        return models.ASNRange.objects.get(pk=id)
     asn_range_list: List[ASNRangeType] = strawberry_django.field()
 
     @strawberry.field
     def aggregate(self, id: int) -> AggregateType:
-        return models.Aggregate.objects.get(id=id)
+        return models.Aggregate.objects.get(pk=id)
     aggregate_list: List[AggregateType] = strawberry_django.field()
 
     @strawberry.field
     def ip_address(self, id: int) -> IPAddressType:
-        return models.IPAddress.objects.get(id=id)
+        return models.IPAddress.objects.get(pk=id)
     ip_address_list: List[IPAddressType] = strawberry_django.field()
 
     @strawberry.field
     def ip_range(self, id: int) -> IPRangeType:
-        return models.IPRange.objects.get(id=id)
+        return models.IPRange.objects.get(pk=id)
     ip_range_list: List[IPRangeType] = strawberry_django.field()
 
     @strawberry.field
     def prefix(self, id: int) -> PrefixType:
-        return models.Prefix.objects.get(id=id)
+        return models.Prefix.objects.get(pk=id)
     prefix_list: List[PrefixType] = strawberry_django.field()
 
     @strawberry.field
     def rir(self, id: int) -> RIRType:
-        return models.RIR.objects.get(id=id)
+        return models.RIR.objects.get(pk=id)
     rir_list: List[RIRType] = strawberry_django.field()
 
     @strawberry.field
     def role(self, id: int) -> RoleType:
-        return models.Role.objects.get(id=id)
+        return models.Role.objects.get(pk=id)
     role_list: List[RoleType] = strawberry_django.field()
 
     @strawberry.field
     def route_target(self, id: int) -> RouteTargetType:
-        return models.RouteTarget.objects.get(id=id)
+        return models.RouteTarget.objects.get(pk=id)
     route_target_list: List[RouteTargetType] = strawberry_django.field()
 
     @strawberry.field
     def service(self, id: int) -> ServiceType:
-        return models.Service.objects.get(id=id)
+        return models.Service.objects.get(pk=id)
     service_list: List[ServiceType] = strawberry_django.field()
 
     @strawberry.field
     def service_template(self, id: int) -> ServiceTemplateType:
-        return models.ServiceTemplate.objects.get(id=id)
+        return models.ServiceTemplate.objects.get(pk=id)
     service_template_list: List[ServiceTemplateType] = strawberry_django.field()
 
     @strawberry.field
     def fhrp_group(self, id: int) -> FHRPGroupType:
-        return models.FHRPGroup.objects.get(id=id)
+        return models.FHRPGroup.objects.get(pk=id)
     fhrp_group_list: List[FHRPGroupType] = strawberry_django.field()
 
     @strawberry.field
     def fhrp_group_assignment(self, id: int) -> FHRPGroupAssignmentType:
-        return models.FHRPGroupAssignment.objects.get(id=id)
+        return models.FHRPGroupAssignment.objects.get(pk=id)
     fhrp_group_assignment_list: List[FHRPGroupAssignmentType] = strawberry_django.field()
 
     @strawberry.field
     def vlan(self, id: int) -> VLANType:
-        return models.VLAN.objects.get(id=id)
+        return models.VLAN.objects.get(pk=id)
     vlan_list: List[VLANType] = strawberry_django.field()
 
     @strawberry.field
     def vlan_group(self, id: int) -> VLANGroupType:
-        return models.VLANGroup.objects.get(id=id)
+        return models.VLANGroup.objects.get(pk=id)
     vlan_group_list: List[VLANGroupType] = strawberry_django.field()
 
     @strawberry.field
     def vrf(self, id: int) -> VRFType:
-        return models.VRF.objects.get(id=id)
+        return models.VRF.objects.get(pk=id)
     vrf_list: List[VRFType] = strawberry_django.field()

+ 121 - 115
netbox/netbox/graphql/filter_mixins.py

@@ -11,6 +11,123 @@ from utilities.fields import ColorField, CounterCacheField
 from utilities.filters import *
 
 
+def map_strawberry_type(field):
+    should_create_function = False
+    attr_type = None
+
+    # NetBox Filter types - put base classes after derived classes
+    if isinstance(field, ContentTypeFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif isinstance(field, MACAddressFilter):
+        pass
+    elif isinstance(field, MultiValueArrayFilter):
+        pass
+    elif isinstance(field, MultiValueCharFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif isinstance(field, MultiValueDateFilter):
+        attr_type = auto
+    elif isinstance(field, MultiValueDateTimeFilter):
+        attr_type = auto
+    elif isinstance(field, MultiValueDecimalFilter):
+        pass
+    elif isinstance(field, MultiValueMACAddressFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif isinstance(field, MultiValueNumberFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif isinstance(field, MultiValueTimeFilter):
+        pass
+    elif isinstance(field, MultiValueWWNFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif isinstance(field, NullableCharFieldFilter):
+        pass
+    elif isinstance(field, NumericArrayFilter):
+        should_create_function = True
+        attr_type = int
+    elif isinstance(field, TreeNodeMultipleChoiceFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+
+    # From django_filters - ordering of these matters as base classes must
+    # come after derived classes so the base class doesn't get matched first
+    # a pass for the check (no attr_type) means we don't currently handle
+    # or use that type
+    elif issubclass(type(field), django_filters.OrderingFilter):
+        pass
+    elif issubclass(type(field), django_filters.BaseRangeFilter):
+        pass
+    elif issubclass(type(field), django_filters.BaseInFilter):
+        pass
+    elif issubclass(type(field), django_filters.LookupChoiceFilter):
+        pass
+    elif issubclass(type(field), django_filters.AllValuesMultipleFilter):
+        pass
+    elif issubclass(type(field), django_filters.AllValuesFilter):
+        pass
+    elif issubclass(type(field), django_filters.TimeRangeFilter):
+        pass
+    elif issubclass(type(field), django_filters.IsoDateTimeFromToRangeFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.DateTimeFromToRangeFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.DateFromToRangeFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.DateRangeFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.RangeFilter):
+        pass
+    elif issubclass(type(field), django_filters.NumericRangeFilter):
+        pass
+    elif issubclass(type(field), django_filters.NumberFilter):
+        should_create_function = True
+        attr_type = int
+    elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif issubclass(type(field), django_filters.ModelChoiceFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.DurationFilter):
+        pass
+    elif issubclass(type(field), django_filters.IsoDateTimeFilter):
+        pass
+    elif issubclass(type(field), django_filters.DateTimeFilter):
+        attr_type = auto
+    elif issubclass(type(field), django_filters.TimeFilter):
+        attr_type = auto
+    elif issubclass(type(field), django_filters.DateFilter):
+        attr_type = auto
+    elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
+        pass
+    elif issubclass(type(field), django_filters.MultipleChoiceFilter):
+        should_create_function = True
+        attr_type = List[str] | None
+    elif issubclass(type(field), django_filters.TypedChoiceFilter):
+        pass
+    elif issubclass(type(field), django_filters.ChoiceFilter):
+        pass
+    elif issubclass(type(field), django_filters.BooleanFilter):
+        should_create_function = True
+        attr_type = bool | None
+    elif issubclass(type(field), django_filters.UUIDFilter):
+        should_create_function = True
+        attr_type = str | None
+    elif issubclass(type(field), django_filters.CharFilter):
+        # looks like only used by 'q'
+        should_create_function = True
+        attr_type = str | None
+
+    return should_create_function, attr_type
+
+
 def autotype_decorator(filterset):
     """
     Decorator used to auto creates a dataclass used by Strawberry based on a filterset.
@@ -36,10 +153,10 @@ def autotype_decorator(filterset):
         if fieldname not in cls.__annotations__ and attr_type:
             cls.__annotations__[fieldname] = attr_type
 
-        fname = f"filter_{fieldname}"
-        if should_create_function and not hasattr(cls, fname):
+        filter_name = f"filter_{fieldname}"
+        if should_create_function and not hasattr(cls, filter_name):
             filter_by_filterset = getattr(cls, 'filter_by_filterset')
-            setattr(cls, fname, partialmethod(filter_by_filterset, key=fieldname))
+            setattr(cls, filter_name, partialmethod(filter_by_filterset, key=fieldname))
 
     def wrapper(cls):
         cls.filterset = filterset
@@ -64,119 +181,8 @@ def autotype_decorator(filterset):
 
         declared_filters = filterset.declared_filters
         for fieldname, field in declared_filters.items():
-            should_create_function = False
-            attr_type = None
-
-            # NetBox Filter types - put base classes after derived classes
-            if isinstance(field, ContentTypeFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif isinstance(field, MACAddressFilter):
-                pass
-            elif isinstance(field, MultiValueArrayFilter):
-                pass
-            elif isinstance(field, MultiValueCharFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif isinstance(field, MultiValueDateFilter):
-                attr_type = auto
-            elif isinstance(field, MultiValueDateTimeFilter):
-                attr_type = auto
-            elif isinstance(field, MultiValueDecimalFilter):
-                pass
-            elif isinstance(field, MultiValueMACAddressFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif isinstance(field, MultiValueNumberFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif isinstance(field, MultiValueTimeFilter):
-                pass
-            elif isinstance(field, MultiValueWWNFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif isinstance(field, NullableCharFieldFilter):
-                pass
-            elif isinstance(field, NumericArrayFilter):
-                should_create_function = True
-                attr_type = int
-            elif isinstance(field, TreeNodeMultipleChoiceFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-
-            # From django_filters - ordering of these matters as base classes must
-            # come after derived classes so the base class doesn't get matched first
-            # a pass for the check (no attr_type) means we don't currently handle
-            # or use that type
-            elif issubclass(type(field), django_filters.OrderingFilter):
-                pass
-            elif issubclass(type(field), django_filters.BaseRangeFilter):
-                pass
-            elif issubclass(type(field), django_filters.BaseInFilter):
-                pass
-            elif issubclass(type(field), django_filters.LookupChoiceFilter):
-                pass
-            elif issubclass(type(field), django_filters.AllValuesMultipleFilter):
-                pass
-            elif issubclass(type(field), django_filters.AllValuesFilter):
-                pass
-            elif issubclass(type(field), django_filters.TimeRangeFilter):
-                pass
-            elif issubclass(type(field), django_filters.IsoDateTimeFromToRangeFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.DateTimeFromToRangeFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.DateFromToRangeFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.DateRangeFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.RangeFilter):
-                pass
-            elif issubclass(type(field), django_filters.NumericRangeFilter):
-                pass
-            elif issubclass(type(field), django_filters.NumberFilter):
-                should_create_function = True
-                attr_type = int
-            elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif issubclass(type(field), django_filters.ModelChoiceFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.DurationFilter):
-                pass
-            elif issubclass(type(field), django_filters.IsoDateTimeFilter):
-                pass
-            elif issubclass(type(field), django_filters.DateTimeFilter):
-                attr_type = auto
-            elif issubclass(type(field), django_filters.TimeFilter):
-                attr_type = auto
-            elif issubclass(type(field), django_filters.DateFilter):
-                attr_type = auto
-            elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
-                pass
-            elif issubclass(type(field), django_filters.MultipleChoiceFilter):
-                should_create_function = True
-                attr_type = List[str] | None
-            elif issubclass(type(field), django_filters.TypedChoiceFilter):
-                pass
-            elif issubclass(type(field), django_filters.ChoiceFilter):
-                pass
-            elif issubclass(type(field), django_filters.BooleanFilter):
-                should_create_function = True
-                attr_type = bool | None
-            elif issubclass(type(field), django_filters.UUIDFilter):
-                should_create_function = True
-                attr_type = str | None
-            elif issubclass(type(field), django_filters.CharFilter):
-                # looks like only used by 'q'
-                should_create_function = True
-                attr_type = str | None
 
+            should_create_function, attr_type = map_strawberry_type(field)
             if attr_type is None:
                 raise NotImplementedError(f"GraphQL Filter field unknown: {fieldname}: {field}")
 

+ 6 - 6
netbox/tenancy/graphql/schema.py

@@ -11,30 +11,30 @@ from .types import *
 class TenancyQuery:
     @strawberry.field
     def tenant(self, id: int) -> TenantType:
-        return models.Tenant.objects.get(id=id)
+        return models.Tenant.objects.get(pk=id)
     tenant_list: List[TenantType] = strawberry_django.field()
 
     @strawberry.field
     def tenant_group(self, id: int) -> TenantGroupType:
-        return models.TenantGroup.objects.get(id=id)
+        return models.TenantGroup.objects.get(pk=id)
     tenant_group_list: List[TenantGroupType] = strawberry_django.field()
 
     @strawberry.field
     def contact(self, id: int) -> ContactType:
-        return models.Contact.objects.get(id=id)
+        return models.Contact.objects.get(pk=id)
     contact_list: List[ContactType] = strawberry_django.field()
 
     @strawberry.field
     def contact_role(self, id: int) -> ContactRoleType:
-        return models.ContactRole.objects.get(id=id)
+        return models.ContactRole.objects.get(pk=id)
     contact_role_list: List[ContactRoleType] = strawberry_django.field()
 
     @strawberry.field
     def contact_group(self, id: int) -> ContactGroupType:
-        return models.ContactGroup.objects.get(id=id)
+        return models.ContactGroup.objects.get(pk=id)
     contact_group_list: List[ContactGroupType] = strawberry_django.field()
 
     @strawberry.field
     def contact_assignment(self, id: int) -> ContactAssignmentType:
-        return models.ContactAssignment.objects.get(id=id)
+        return models.ContactAssignment.objects.get(pk=id)
     contact_assignment_list: List[ContactAssignmentType] = strawberry_django.field()

+ 1 - 8
netbox/tenancy/graphql/types.py

@@ -6,6 +6,7 @@ import strawberry_django
 from extras.graphql.mixins import CustomFieldsMixin, TagsMixin
 from netbox.graphql.types import BaseObjectType, OrganizationalObjectType, NetBoxObjectType
 from tenancy import models
+from .mixins import ContactAssignmentsMixin
 from .filters import *
 
 __all__ = (
@@ -18,14 +19,6 @@ __all__ = (
 )
 
 
-@strawberry.type
-class ContactAssignmentsMixin:
-
-    @strawberry_django.field
-    def assignments(self) -> List[Annotated["ContactAssignmentType", strawberry.lazy('tenancy.graphql.types')]]:
-        return self.assignments.all()
-
-
 #
 # Tenants
 #

+ 2 - 2
netbox/users/graphql/schema.py

@@ -12,10 +12,10 @@ from .types import *
 class UsersQuery:
     @strawberry.field
     def group(self, id: int) -> GroupType:
-        return models.Group.objects.get(id=id)
+        return models.Group.objects.get(pk=id)
     group_list: List[GroupType] = strawberry_django.field()
 
     @strawberry.field
     def user(self, id: int) -> UserType:
-        return models.User.objects.get(id=id)
+        return models.User.objects.get(pk=id)
     user_list: List[UserType] = strawberry_django.field()

+ 1 - 7
netbox/users/graphql/types.py

@@ -22,9 +22,7 @@ __all__ = (
     filters=GroupFilter
 )
 class GroupType:
-    @classmethod
-    def get_queryset(cls, queryset, info, **kwargs):
-        return RestrictedQuerySet(model=Group).restrict(info.context.request.user, 'view')
+    pass
 
 
 @strawberry_django.type(
@@ -36,10 +34,6 @@ class GroupType:
     filters=UserFilter
 )
 class UserType:
-    @classmethod
-    def get_queryset(cls, queryset, info, **kwargs):
-        return RestrictedQuerySet(model=get_user_model()).restrict(info.context.request.user, 'view')
-
     @strawberry_django.field
     def groups(self) -> List[GroupType]:
         return self.groups.all()

+ 4 - 4
netbox/utilities/testing/api.py

@@ -451,12 +451,12 @@ class APIViewTestCases:
             # Compile list of fields to include
             fields_string = ''
 
+            file_fields = (strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType)
             for field in type_class.__strawberry_definition__.fields:
                 if (
-                    field.type in (
-                        strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType) or
-                    type(field.type) is StrawberryOptional and field.type.of_type in (
-                        strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType)
+                    field.type in file_fields or (
+                        type(field.type) is StrawberryOptional and field.type.of_type in file_fields
+                    )
                 ):
                     # image / file fields nullable or not...
                     fields_string += f'{field.name} {{ name }}\n'

+ 6 - 6
netbox/virtualization/graphql/schema.py

@@ -11,30 +11,30 @@ from .types import *
 class VirtualizationQuery:
     @strawberry.field
     def cluster(self, id: int) -> ClusterType:
-        return models.Cluster.objects.get(id=id)
+        return models.Cluster.objects.get(pk=id)
     cluster_list: List[ClusterType] = strawberry_django.field()
 
     @strawberry.field
     def cluster_group(self, id: int) -> ClusterGroupType:
-        return models.ClusterGroup.objects.get(id=id)
+        return models.ClusterGroup.objects.get(pk=id)
     cluster_group_list: List[ClusterGroupType] = strawberry_django.field()
 
     @strawberry.field
     def cluster_type(self, id: int) -> ClusterTypeType:
-        return models.ClusterType.objects.get(id=id)
+        return models.ClusterType.objects.get(pk=id)
     cluster_type_list: List[ClusterTypeType] = strawberry_django.field()
 
     @strawberry.field
     def virtual_machine(self, id: int) -> VirtualMachineType:
-        return models.VirtualMachine.objects.get(id=id)
+        return models.VirtualMachine.objects.get(pk=id)
     virtual_machine_list: List[VirtualMachineType] = strawberry_django.field()
 
     @strawberry.field
     def vm_interface(self, id: int) -> VMInterfaceType:
-        return models.VMInterface.objects.get(id=id)
+        return models.VMInterface.objects.get(pk=id)
     vm_interface_list: List[VMInterfaceType] = strawberry_django.field()
 
     @strawberry.field
     def virtual_disk(self, id: int) -> VirtualDiskType:
-        return models.VirtualDisk.objects.get(id=id)
+        return models.VirtualDisk.objects.get(pk=id)
     virtual_disk_list: List[VirtualDiskType] = strawberry_django.field()

+ 10 - 10
netbox/vpn/graphql/schema.py

@@ -11,50 +11,50 @@ from .types import *
 class VPNQuery:
     @strawberry.field
     def ike_policy(self, id: int) -> IKEPolicyType:
-        return models.IKEPolicy.objects.get(id=id)
+        return models.IKEPolicy.objects.get(pk=id)
     ike_policy_list: List[IKEPolicyType] = strawberry_django.field()
 
     @strawberry.field
     def ike_proposal(self, id: int) -> IKEProposalType:
-        return models.IKEProposal.objects.get(id=id)
+        return models.IKEProposal.objects.get(pk=id)
     ike_proposal_list: List[IKEProposalType] = strawberry_django.field()
 
     @strawberry.field
     def ipsec_policy(self, id: int) -> IPSecPolicyType:
-        return models.IPSecPolicy.objects.get(id=id)
+        return models.IPSecPolicy.objects.get(pk=id)
     ipsec_policy_list: List[IPSecPolicyType] = strawberry_django.field()
 
     @strawberry.field
     def ipsec_profile(self, id: int) -> IPSecProfileType:
-        return models.IPSecProfile.objects.get(id=id)
+        return models.IPSecProfile.objects.get(pk=id)
     ipsec_profile_list: List[IPSecProfileType] = strawberry_django.field()
 
     @strawberry.field
     def ipsec_proposal(self, id: int) -> IPSecProposalType:
-        return models.IPSecProposal.objects.get(id=id)
+        return models.IPSecProposal.objects.get(pk=id)
     ipsec_proposal_list: List[IPSecProposalType] = strawberry_django.field()
 
     @strawberry.field
     def l2vpn(self, id: int) -> L2VPNType:
-        return models.L2VPN.objects.get(id=id)
+        return models.L2VPN.objects.get(pk=id)
     l2vpn_list: List[L2VPNType] = strawberry_django.field()
 
     @strawberry.field
     def l2vpn_termination(self, id: int) -> L2VPNTerminationType:
-        return models.L2VPNTermination.objects.get(id=id)
+        return models.L2VPNTermination.objects.get(pk=id)
     l2vpn_termination_list: List[L2VPNTerminationType] = strawberry_django.field()
 
     @strawberry.field
     def tunnel(self, id: int) -> TunnelType:
-        return models.Tunnel.objects.get(id=id)
+        return models.Tunnel.objects.get(pk=id)
     tunnel_list: List[TunnelType] = strawberry_django.field()
 
     @strawberry.field
     def tunnel_group(self, id: int) -> TunnelGroupType:
-        return models.TunnelGroup.objects.get(id=id)
+        return models.TunnelGroup.objects.get(pk=id)
     tunnel_group_list: List[TunnelGroupType] = strawberry_django.field()
 
     @strawberry.field
     def tunnel_termination(self, id: int) -> TunnelTerminationType:
-        return models.TunnelTermination.objects.get(id=id)
+        return models.TunnelTermination.objects.get(pk=id)
     tunnel_termination_list: List[TunnelTerminationType] = strawberry_django.field()

+ 3 - 3
netbox/wireless/graphql/schema.py

@@ -11,15 +11,15 @@ from .types import *
 class WirelessQuery:
     @strawberry.field
     def wireless_lan(self, id: int) -> WirelessLANType:
-        return models.WirelessLAN.objects.get(id=id)
+        return models.WirelessLAN.objects.get(pk=id)
     wireless_lan_list: List[WirelessLANType] = strawberry_django.field()
 
     @strawberry.field
     def wireless_lan_group(self, id: int) -> WirelessLANGroupType:
-        return models.WirelessLANGroup.objects.get(id=id)
+        return models.WirelessLANGroup.objects.get(pk=id)
     wireless_lan_group_list: List[WirelessLANGroupType] = strawberry_django.field()
 
     @strawberry.field
     def wireless_link(self, id: int) -> WirelessLinkType:
-        return models.WirelessLink.objects.get(id=id)
+        return models.WirelessLink.objects.get(pk=id)
     wireless_link_list: List[WirelessLinkType] = strawberry_django.field()

+ 2 - 1
requirements.txt

@@ -30,7 +30,8 @@ PyYAML==6.0.1
 requests==2.31.0
 social-auth-app-django==5.4.0
 social-auth-core[openidconnect]==4.5.3
-strawberry-graphql-django==0.33.0
+strawberry-graphql==0.220.0
+strawberry-graphql-django==0.35.1
 svgwrite==1.4.3
 tablib==3.5.0
 tzdata==2024.1