Selaa lähdekoodia

Added ViewTestCase (WIP)

Jeremy Stretch 6 vuotta sitten
vanhempi
commit
98cce7eee4
2 muutettua tiedostoa jossa 206 lisäystä ja 97 poistoa
  1. 60 97
      netbox/circuits/tests/test_views.py
  2. 146 0
      netbox/utilities/testing.py

+ 60 - 97
netbox/circuits/tests/test_views.py

@@ -1,58 +1,59 @@
-import urllib.parse
-
-from django.urls import reverse
+import datetime
 
 
+from circuits.choices import *
 from circuits.models import Circuit, CircuitType, Provider
 from circuits.models import Circuit, CircuitType, Provider
-from utilities.testing import TestCase
-
-
-class ProviderTestCase(TestCase):
-    user_permissions = (
-        'circuits.view_provider',
+from utilities.testing import ViewTestCase
+
+
+class ProviderTestCase(ViewTestCase):
+    model = Provider
+    form_data = {
+        'name': 'Provider X',
+        'slug': 'provider-x',
+        'asn': 65123,
+        'account': '1234',
+        'portal_url': 'http://example.com/portal',
+        'noc_contact': 'noc@example.com',
+        'admin_contact': 'admin@example.com',
+        'comments': 'Another provider',
+        'tags': 'Alpha,Bravo,Charlie',
+    }
+    csv_data = (
+        "name,slug",
+        "Provider 4,provider-4",
+        "Provider 5,provider-5",
+        "Provider 6,provider-6",
     )
     )
 
 
     @classmethod
     @classmethod
     def setUpTestData(cls):
     def setUpTestData(cls):
+
         Provider.objects.bulk_create([
         Provider.objects.bulk_create([
             Provider(name='Provider 1', slug='provider-1', asn=65001),
             Provider(name='Provider 1', slug='provider-1', asn=65001),
             Provider(name='Provider 2', slug='provider-2', asn=65002),
             Provider(name='Provider 2', slug='provider-2', asn=65002),
             Provider(name='Provider 3', slug='provider-3', asn=65003),
             Provider(name='Provider 3', slug='provider-3', asn=65003),
         ])
         ])
 
 
-    def test_provider_list(self):
-        url = reverse('circuits:provider_list')
-        params = {
-            "q": "test",
-        }
-
-        response = self.client.get('{}?{}'.format(url, urllib.parse.urlencode(params)))
-        self.assertHttpStatus(response, 200)
-
-    def test_provider(self):
-        provider = Provider.objects.first()
-        response = self.client.get(provider.get_absolute_url())
-        self.assertHttpStatus(response, 200)
-
-    def test_provider_import(self):
-        self.add_permissions('circuits.add_provider')
-        csv_data = (
-            "name,slug",
-            "Provider 4,provider-4",
-            "Provider 5,provider-5",
-            "Provider 6,provider-6",
-        )
 
 
-        response = self.client.post(reverse('circuits:provider_import'), {'csv': '\n'.join(csv_data)})
-
-        self.assertHttpStatus(response, 200)
-        self.assertEqual(Provider.objects.count(), 6)
-
-
-class CircuitTypeTestCase(TestCase):
-    user_permissions = (
-        'circuits.view_circuittype',
+class CircuitTypeTestCase(ViewTestCase):
+    model = CircuitType
+    views = ('list', 'add', 'edit', 'import')
+    form_data = {
+        'name': 'Circuit Type X',
+        'slug': 'circuit-type-x',
+        'description': 'A new circuit type',
+    }
+    csv_data = (
+        "name,slug",
+        "Circuit Type 4,circuit-type-4",
+        "Circuit Type 5,circuit-type-5",
+        "Circuit Type 6,circuit-type-6",
     )
     )
 
 
+    # Disable inapplicable tests
+    test_get_object = None
+    test_delete_object = None
+
     @classmethod
     @classmethod
     def setUpTestData(cls):
     def setUpTestData(cls):
 
 
@@ -62,32 +63,26 @@ class CircuitTypeTestCase(TestCase):
             CircuitType(name='Circuit Type 3', slug='circuit-type-3'),
             CircuitType(name='Circuit Type 3', slug='circuit-type-3'),
         ])
         ])
 
 
-    def test_circuittype_list(self):
-
-        url = reverse('circuits:circuittype_list')
-
-        response = self.client.get(url)
-        self.assertHttpStatus(response, 200)
-
-    def test_circuittype_import(self):
-        self.add_permissions('circuits.add_circuittype')
 
 
-        csv_data = (
-            "name,slug",
-            "Circuit Type 4,circuit-type-4",
-            "Circuit Type 5,circuit-type-5",
-            "Circuit Type 6,circuit-type-6",
-        )
-
-        response = self.client.post(reverse('circuits:circuittype_import'), {'csv': '\n'.join(csv_data)})
-
-        self.assertHttpStatus(response, 200)
-        self.assertEqual(CircuitType.objects.count(), 6)
-
-
-class CircuitTestCase(TestCase):
-    user_permissions = (
-        'circuits.view_circuit',
+class CircuitTestCase(ViewTestCase):
+    model = Circuit
+    # TODO: Determine how to lazily resolve related objects
+    form_data = {
+        'cid': 'Circuit X',
+        'provider': Provider.objects.first(),
+        'type': CircuitType.objects.first(),
+        'status': CircuitStatusChoices.STATUS_ACTIVE,
+        'tenant': None,
+        'install_date': datetime.date(2020, 1, 1),
+        'commit_rate': 1000,
+        'description': 'A new circuit',
+        'comments': 'Some comments',
+    }
+    csv_data = (
+        "cid,provider,type",
+        "Circuit 4,Provider 1,Circuit Type 1",
+        "Circuit 5,Provider 1,Circuit Type 1",
+        "Circuit 6,Provider 1,Circuit Type 1",
     )
     )
 
 
     @classmethod
     @classmethod
@@ -104,35 +99,3 @@ class CircuitTestCase(TestCase):
             Circuit(cid='Circuit 2', provider=provider, type=circuittype),
             Circuit(cid='Circuit 2', provider=provider, type=circuittype),
             Circuit(cid='Circuit 3', provider=provider, type=circuittype),
             Circuit(cid='Circuit 3', provider=provider, type=circuittype),
         ])
         ])
-
-    def test_circuit_list(self):
-
-        url = reverse('circuits:circuit_list')
-        params = {
-            "provider": Provider.objects.first().slug,
-            "type": CircuitType.objects.first().slug,
-        }
-
-        response = self.client.get('{}?{}'.format(url, urllib.parse.urlencode(params)))
-        self.assertHttpStatus(response, 200)
-
-    def test_circuit(self):
-
-        circuit = Circuit.objects.first()
-        response = self.client.get(circuit.get_absolute_url())
-        self.assertHttpStatus(response, 200)
-
-    def test_circuit_import(self):
-        self.add_permissions('circuits.add_circuit')
-
-        csv_data = (
-            "cid,provider,type",
-            "Circuit 4,Provider 1,Circuit Type 1",
-            "Circuit 5,Provider 1,Circuit Type 1",
-            "Circuit 6,Provider 1,Circuit Type 1",
-        )
-
-        response = self.client.post(reverse('circuits:circuit_import'), {'csv': '\n'.join(csv_data)})
-
-        self.assertHttpStatus(response, 200)
-        self.assertEqual(Circuit.objects.count(), 6)

+ 146 - 0
netbox/utilities/testing.py

@@ -2,8 +2,10 @@ import logging
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
 from django.contrib.auth.models import Permission, User
 from django.contrib.auth.models import Permission, User
+from django.core.exceptions import ObjectDoesNotExist
 from django.forms.models import model_to_dict as _model_to_dict
 from django.forms.models import model_to_dict as _model_to_dict
 from django.test import Client, TestCase as _TestCase
 from django.test import Client, TestCase as _TestCase
+from django.urls import reverse
 from rest_framework.test import APIClient
 from rest_framework.test import APIClient
 
 
 from users.models import Token
 from users.models import Token
@@ -70,6 +72,133 @@ class APITestCase(TestCase):
         self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
         self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
 
 
 
 
+# TODO: Omit this from tests
+class ViewTestCase(TestCase):
+    """
+    Stock TestCase suitable for testing all standard View functions:
+        - List objects
+        - View single object
+        - Create new object
+        - Modify existing object
+        - Delete existing object
+        - Import multiple new objects
+    """
+    model = None
+    form_data = {}
+    csv_data = {}
+
+    def __init__(self, *args, **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        self.base_url_name = '{}:{}_{{}}'.format(self.model._meta.app_label, self.model._meta.model_name)
+
+    def test_list_objects(self):
+        response = self.client.get(reverse(self.base_url_name.format('list')))
+        self.assertHttpStatus(response, 200)
+
+    def test_get_object(self):
+        instance = self.model.objects.first()
+        response = self.client.get(instance.get_absolute_url())
+        self.assertHttpStatus(response, 200)
+
+    def test_create_object(self):
+        initial_count = self.model.objects.count()
+        request = {
+            'path': reverse(self.base_url_name.format('add')),
+            'data': post_data(self.form_data),
+            'follow': True,
+        }
+        print(request['data'])
+
+        # Attempt to make the request without required permissions
+        with disable_warnings('django.request'):
+            self.assertHttpStatus(self.client.post(**request), 403)
+
+        # Assign the required permission and submit again
+        self.add_permissions('{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name))
+        response = self.client.post(**request)
+        self.assertHttpStatus(response, 200)
+
+        self.assertEqual(initial_count, self.model.objects.count() + 1)
+        instance = self.model.objects.order_by('-pk').first()
+        self.assertDictEqual(model_to_dict(instance), self.form_data)
+
+    def test_edit_object(self):
+        instance = self.model.objects.first()
+
+        # Determine the proper kwargs to pass to the edit URL
+        if hasattr(instance, 'slug'):
+            kwargs = {'slug': instance.slug}
+        else:
+            kwargs = {'pk': instance.pk}
+
+        request = {
+            'path': reverse(self.base_url_name.format('edit'), kwargs=kwargs),
+            'data': post_data(self.form_data),
+            'follow': True,
+        }
+
+        # Attempt to make the request without required permissions
+        with disable_warnings('django.request'):
+            self.assertHttpStatus(self.client.post(**request), 403)
+
+        # Assign the required permission and submit again
+        self.add_permissions('{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name))
+        response = self.client.post(**request)
+        self.assertHttpStatus(response, 200)
+
+        instance = self.model.objects.get(pk=instance.pk)
+        self.assertDictEqual(model_to_dict(instance), self.form_data)
+
+    def test_delete_object(self):
+        instance = self.model.objects.first()
+
+        # Determine the proper kwargs to pass to the deletion URL
+        if hasattr(instance, 'slug'):
+            kwargs = {'slug': instance.slug}
+        else:
+            kwargs = {'pk': instance.pk}
+
+        request = {
+            'path': reverse(self.base_url_name.format('delete'), kwargs=kwargs),
+            'data': {'confirm': True},
+            'follow': True,
+        }
+
+        # Attempt to make the request without required permissions
+        with disable_warnings('django.request'):
+            self.assertHttpStatus(self.client.post(**request), 403)
+
+        # Assign the required permission and submit again
+        self.add_permissions('{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name))
+        response = self.client.post(**request)
+        self.assertHttpStatus(response, 200)
+
+        with self.assertRaises(ObjectDoesNotExist):
+            self.model.objects.get(pk=instance.pk)
+
+    def test_import_objects(self):
+        request = {
+            'path': reverse(self.base_url_name.format('import')),
+            'data': {
+                'csv': '\n'.join(self.csv_data)
+            }
+        }
+        initial_count = self.model.objects.count()
+
+        # Attempt to make the request without required permissions
+        with disable_warnings('django.request'):
+            self.assertHttpStatus(self.client.post(**request), 403)
+
+        # Assign the required permission and submit again
+        self.add_permissions('{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name))
+        response = self.client.post(**request)
+        self.assertHttpStatus(response, 200)
+
+        self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1)
+
+
 def model_to_dict(instance, fields=None, exclude=None):
 def model_to_dict(instance, fields=None, exclude=None):
     """
     """
     Customized wrapper for Django's built-in model_to_dict(). Does the following:
     Customized wrapper for Django's built-in model_to_dict(). Does the following:
@@ -88,6 +217,23 @@ def model_to_dict(instance, fields=None, exclude=None):
     return model_dict
     return model_dict
 
 
 
 
+def post_data(data):
+    """
+    Take a dictionary of test data (suitable for comparison to an instance) and return a dict suitable for POSTing.
+    """
+    ret = {}
+
+    for key, value in data.items():
+        if value is None:
+            ret[key] = ''
+        elif hasattr(value, 'pk'):
+            ret[key] = getattr(value, 'pk')
+        else:
+            ret[key] = str(value)
+
+    return ret
+
+
 def create_test_user(username='testuser', permissions=list()):
 def create_test_user(username='testuser', permissions=list()):
     """
     """
     Create a User with the given permissions.
     Create a User with the given permissions.