Sfoglia il codice sorgente

alternative don't copy base code

Arthur 1 settimana fa
parent
commit
7a89ef193c
1 ha cambiato i file con 24 aggiunte e 33 eliminazioni
  1. 24 33
      netbox/users/api/views.py

+ 24 - 33
netbox/users/api/views.py

@@ -52,39 +52,30 @@ class TokenViewSet(NetBoxModelViewSet):
     serializer_class = serializers.TokenSerializer
     filterset_class = filtersets.TokenFilterSet
 
-    def create(self, request, *args, **kwargs):
-        # This is the same code as NetBoxModelViewSet.create(), but re-copies the plaintext token
-        # value(s) onto the re-fetched instance(s). The parent's create() re-fetches from the
-        # database after perform_create() to attach prefetched related objects, which discards
-        # the in-memory plaintext — for v2 tokens this value cannot be recovered later because
-        # the database stores only an HMAC digest.
-        serializer = self.get_serializer(data=request.data)
-        serializer.is_valid(raise_exception=True)
-        bulk_create = getattr(serializer, 'many', False)
-        self.perform_create(serializer)
-
-        # After creating the instance(s), re-initialize the serializer with a queryset
-        # to ensure related objects are prefetched.
-        if bulk_create:
-            instance_pks = [obj.pk for obj in serializer.instance]
-            # Capture the in-memory plaintext token values; v2 tokens cannot be recovered from the database.
-            plaintexts = {obj.pk: obj.token for obj in serializer.instance}
-            # Order by PK to ensure that the ordering of objects in the response
-            # matches the ordering of those in the request.
-            qs = list(self.get_queryset().filter(pk__in=instance_pks).order_by('pk'))
-            for obj in qs:
-                obj._token = plaintexts[obj.pk]
-        else:
-            # Capture the in-memory plaintext token; v2 tokens cannot be recovered from the database.
-            plaintext = serializer.instance.token
-            qs = self.get_queryset().get(pk=serializer.instance.pk)
-            qs._token = plaintext
-
-        # Re-serialize the instance(s) with prefetched data
-        serializer = self.get_serializer(qs, many=bulk_create)
-
-        headers = self.get_success_headers(serializer.data)
-        return Response(serializer.data, status=HTTP_201_CREATED, headers=headers)
+    def perform_create(self, serializer):
+        super().perform_create(serializer)
+        # The parent create() re-fetches from the database after perform_create() to attach
+        # prefetched relations, which discards the in-memory plaintext token. v2 plaintexts
+        # cannot be recovered later (only an HMAC digest is stored), so stash them here and
+        # restore them in get_serializer() before the response is serialized.
+        instances = serializer.instance if getattr(serializer, 'many', False) else [serializer.instance]
+        self._created_token_plaintexts = {obj.pk: obj.token for obj in instances}
+
+    def get_serializer(self, *args, **kwargs):
+        plaintexts = self.__dict__.pop('_created_token_plaintexts', None)
+        if plaintexts and args:
+            target = args[0]
+            if isinstance(target, Token):
+                target._token = plaintexts.get(target.pk, target._token)
+            else:
+                # Materialize the queryset and patch each instance, then pass the list along
+                # so the patched _token survives into to_representation().
+                objs = list(target)
+                for obj in objs:
+                    if obj.pk in plaintexts:
+                        obj._token = plaintexts[obj.pk]
+                args = (objs, *args[1:])
+        return super().get_serializer(*args, **kwargs)
 
 
 class TokenProvisionView(APIView):