Ver Fonte

Add tests for new functionality in device.py

The asynchronous loop is a bit hard to test.  It may need refactoring
into smaller, testable functions to improve the coverage further.

Don't register a listener for starting the loop later when Home
Assistant is shutting down.  Probably not important, as the listener
will not be notified, but seems cleaner.
Jason Rumney há 3 anos atrás
pai
commit
d94c2bc83a
2 ficheiros alterados com 283 adições e 39 exclusões
  1. 3 1
      custom_components/tuya_local/device.py
  2. 280 38
      tests/test_device.py

+ 3 - 1
custom_components/tuya_local/device.py

@@ -112,7 +112,9 @@ class TuyaLocalDevice(object):
         self._refresh_task = self._hass.async_create_task(self.receive_loop())
 
     def start(self):
-        if self._hass.is_running and not self._hass.is_stopping:
+        if self._hass.is_stopping:
+            return
+        elif self._hass.is_running:
             if self._startup_listener:
                 self._startup_listener()
                 self._startup_listener = None

+ 280 - 38
tests/test_device.py

@@ -1,8 +1,12 @@
-import tinytuya
 from datetime import datetime
 from time import time
 from unittest import IsolatedAsyncioTestCase
-from unittest.mock import AsyncMock, call, patch
+from unittest.mock import AsyncMock, Mock, call, patch
+
+from homeassistant.const import (
+    EVENT_HOMEASSISTANT_STARTED,
+    EVENT_HOMEASSISTANT_STOP,
+)
 
 from custom_components.tuya_local.device import TuyaLocalDevice
 
@@ -21,6 +25,10 @@ class TestDevice(IsolatedAsyncioTestCase):
         self.addCleanup(hass_patcher.stop)
         self.hass = hass_patcher.start()
 
+        lock_patcher = patch("custom_components.tuya_local.device.Lock")
+        self.addCleanup(lock_patcher.stop)
+        self.mock_lock = lock_patcher.start()
+
         self.subject = TuyaLocalDevice(
             "Some name",
             "some_dev_id",
@@ -71,14 +79,14 @@ class TestDevice(IsolatedAsyncioTestCase):
 
         self.subject.async_refresh.assert_awaited()
 
-    async def test_detection_returns_none_when_device_type_could_not_be_detected(self):
+    async def test_detection_returns_none_when_device_type_not_detected(self):
         self.subject._cached_state = {"2": False, "updated_at": datetime.now()}
         self.assertEqual(await self.subject.async_inferred_type(), None)
 
     async def test_refreshes_when_there_is_no_pending_reset(self):
         async_job = AsyncMock()
         self.subject._cached_state = {"updated_at": time() - 19}
-        self.subject._hass.async_add_executor_job.return_value = awaitable = async_job()
+        self.hass().async_add_executor_job.return_value = async_job()
         await self.subject.async_refresh()
 
         async_job.assert_awaited()
@@ -86,24 +94,22 @@ class TestDevice(IsolatedAsyncioTestCase):
     async def test_refreshes_when_there_is_expired_pending_reset(self):
         async_job = AsyncMock()
         self.subject._cached_state = {"updated_at": time() - 20}
-        self.subject._hass.async_add_executor_job.return_value = awaitable = async_job()
+        self.hass().async_add_executor_job.return_value = async_job()
         await self.subject.async_refresh()
 
         async_job.assert_awaited()
 
     async def test_refresh_reloads_status_from_device(self):
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.return_value = awaitable = {
-            "dps": {"1": False}
-        }
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.return_value = {"dps": {"1": False}}
 
         await self.subject.async_refresh()
 
-        self.subject._hass.async_add_executor_job.assert_called_once()
+        self.hass().async_add_executor_job.assert_called_once()
 
     async def test_refresh_retries_up_to_nine_times(self):
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.side_effect = [
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.side_effect = [
             Exception("Error"),
             Exception("Error"),
             Exception("Error"),
@@ -117,18 +123,18 @@ class TestDevice(IsolatedAsyncioTestCase):
 
         await self.subject.async_refresh()
 
-        self.assertEqual(self.subject._hass.async_add_executor_job.call_count, 9)
+        self.assertEqual(self.hass().async_add_executor_job.call_count, 9)
         # self.assertEqual(self.subject._cached_state["1"], False)
 
-    async def test_refresh_clears_cached_state_and_pending_updates_after_failing_nine_times(
+    async def test_refresh_clears_cached_and_pending_after_nine_fails(
         self,
     ):
         self.subject._cached_state = {"1": True}
         self.subject._pending_updates = {
             "1": {"value": False, "updated_at": datetime.now(), "sent": True}
         }
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.side_effect = [
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.side_effect = [
             Exception("Error"),
             Exception("Error"),
             Exception("Error"),
@@ -142,14 +148,14 @@ class TestDevice(IsolatedAsyncioTestCase):
 
         await self.subject.async_refresh()
 
-        self.assertEqual(self.subject._hass.async_add_executor_job.call_count, 9)
+        self.assertEqual(self.hass().async_add_executor_job.call_count, 9)
         self.assertEqual(self.subject._cached_state, {"updated_at": 0})
         self.assertEqual(self.subject._pending_updates, {})
 
     async def test_api_protocol_version_is_rotated_with_each_failure(self):
-        self.subject._api.set_version.reset_mock()
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.side_effect = [
+        self.mock_api().set_version.reset_mock()
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.side_effect = [
             Exception("Error"),
             Exception("Error"),
             Exception("Error"),
@@ -159,14 +165,14 @@ class TestDevice(IsolatedAsyncioTestCase):
         ]
         await self.subject.async_refresh()
 
-        self.subject._api.set_version.assert_has_calls(
+        self.mock_api().set_version.assert_has_calls(
             [call(3.1), call(3.2), call(3.4), call(3.3), call(3.1)]
         )
 
     async def test_api_protocol_version_is_stable_once_successful(self):
-        self.subject._api.set_version.reset_mock()
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.side_effect = [
+        self.mock_api().set_version.reset_mock()
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.side_effect = [
             Exception("Error"),
             Exception("Error"),
             Exception("Error"),
@@ -185,20 +191,24 @@ class TestDevice(IsolatedAsyncioTestCase):
         await self.subject.async_refresh()
         self.assertEqual(self.subject._api_protocol_version_index, 3)
 
-        self.subject._api.set_version.assert_has_calls(
-            [call(3.1), call(3.2), call(3.4)]
+        self.mock_api().set_version.assert_has_calls(
+            [
+                call(3.1),
+                call(3.2),
+                call(3.4),
+            ]
         )
 
     async def test_api_protocol_version_is_not_rotated_when_not_auto(self):
         self.subject._protocol_configured = 3.4
         self.subject._api_protocol_version_index = None
-        self.subject._api.set_version.reset_mock()
+        self.mock_api().set_version.reset_mock()
         self.subject._rotate_api_protocol_version()
-        self.subject._api.set_version.assert_called_once_with(3.4)
-        self.subject._api.set_version.reset_mock()
+        self.mock_api().set_version.assert_called_once_with(3.4)
+        self.mock_api().set_version.reset_mock()
 
-        self.subject._hass.async_add_executor_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.side_effect = [
+        self.hass().async_add_executor_job = AsyncMock()
+        self.hass().async_add_executor_job.side_effect = [
             Exception("Error"),
             Exception("Error"),
             Exception("Error"),
@@ -263,19 +273,19 @@ class TestDevice(IsolatedAsyncioTestCase):
 
     async def test_async_set_property_schedules_job(self):
         async_job = AsyncMock()
-        self.subject._hass.async_add_executor_job.return_value = awaitable = async_job()
+        self.hass().async_add_executor_job.return_value = async_job()
 
         await self.subject.async_set_property("1", False)
 
-        self.subject._hass.async_add_executor_job.assert_called_once()
+        self.hass().async_add_executor_job.assert_called_once()
         async_job.assert_awaited()
 
-    async def test_set_property_immediately_stores_new_value_to_pending_updates(self):
+    async def test_set_property_immediately_stores_pending_updates(self):
         self.subject._cached_state = {"1": True}
         await self.subject.async_set_property("1", False)
         self.assertFalse(self.subject.get_property("1"))
 
-    async def test_set_properties_takes_no_action_when_no_properties_are_provided(self):
+    async def test_set_properties_takes_no_action_when_nothing_provided(self):
         with patch("asyncio.sleep") as mock:
             await self.subject.async_set_properties({})
             mock.assert_not_called()
@@ -288,11 +298,243 @@ class TestDevice(IsolatedAsyncioTestCase):
     def test_get_key_for_value_returns_key_from_object_matching_value(self):
         obj = {"key1": "value1", "key2": "value2"}
 
-        self.assertEqual(TuyaLocalDevice.get_key_for_value(obj, "value1"), "key1")
-        self.assertEqual(TuyaLocalDevice.get_key_for_value(obj, "value2"), "key2")
+        self.assertEqual(
+            TuyaLocalDevice.get_key_for_value(obj, "value1"),
+            "key1",
+        )
+        self.assertEqual(
+            TuyaLocalDevice.get_key_for_value(obj, "value2"),
+            "key2",
+        )
 
     def test_get_key_for_value_returns_fallback_when_value_not_found(self):
         obj = {"key1": "value1", "key2": "value2"}
         self.assertEqual(
-            TuyaLocalDevice.get_key_for_value(obj, "value3", fallback="fb"), "fb"
+            TuyaLocalDevice.get_key_for_value(obj, "value3", fallback="fb"),
+            "fb",
         )
+
+    def test_refresh_cached_state(self):
+        # set up preconditions
+        self.mock_api().status.return_value = {"dps": {"1": "CHANGED"}}
+        self.subject._cached_state = {"1": "UNCHANGED", "updated_at": 123}
+
+        # call the function under test
+        self.subject._refresh_cached_state()
+
+        # Did it call the API as expected?
+        self.mock_api().status.assert_called_once()
+        # Did it update the cached state?
+        self.assertDictEqual(
+            self.subject._cached_state,
+            {"1": "CHANGED"} | self.subject._cached_state,
+        )
+        # Did it update the timestamp on the cached state?
+        self.assertAlmostEqual(
+            self.subject._cached_state["updated_at"],
+            time(),
+            delta=2,
+        )
+
+    def test_send_payload(self):
+        # set up preconditions
+        self.subject._pending_updates = {
+            "1": {"value": "sample", "updated_at": time() - 2, "sent": False},
+        }
+
+        # call the function under test
+        self.subject._send_payload("PAYLOAD")
+
+        # did it send what it was asked?
+        self.mock_api().send.assert_called_once_with("PAYLOAD")
+        # did it mark the pending updates as sent?
+        self.assertTrue(self.subject._pending_updates["1"]["sent"])
+        # did it update the time on the pending updates?
+        self.assertAlmostEqual(
+            self.subject._pending_updates["1"]["updated_at"],
+            time(),
+            delta=2,
+        )
+        # did it lock and unlock when sending
+        self.subject._lock.acquire.assert_called_once()
+        self.subject._lock.release.assert_called_once()
+
+    def test_actually_start(self):
+        # Set up the preconditions
+        self.subject.receive_loop = Mock()
+        self.subject.receive_loop.return_value = "LOOP"
+        self.hass().bus.async_listen_once.return_value = "LISTENER"
+        self.subject._running = False
+
+        # run the function under test
+        self.subject.actually_start()
+
+        # did it register a listener for EVENT_HOMEASSISTANT_STOP?
+        self.hass().bus.async_listen_once.assert_called_once_with(
+            EVENT_HOMEASSISTANT_STOP, self.subject.async_stop
+        )
+        self.assertEqual(self.subject._shutdown_listener, "LISTENER")
+        # did it set the running flag?
+        self.assertTrue(self.subject._running)
+        # did it schedule the loop?
+        self.hass().async_create_task.assert_called_once_with("LOOP")
+
+    def test_start_starts_when_ha_running(self):
+        # Set up preconditions
+        self.hass().is_running = True
+        self.hass().is_stopping = False
+        listener = Mock()
+        self.subject._startup_listener = listener
+        self.subject.actually_start = Mock()
+
+        # Call the function under test
+        self.subject.start()
+
+        # Did it actually start?
+        self.subject.actually_start.assert_called_once()
+        # Did it cancel the startup listener?
+        self.assertIsNone(self.subject._startup_listener)
+        listener.assert_called_once()
+
+    def test_start_schedules_for_later_when_ha_starting(self):
+        # Set up preconditions
+        self.hass().is_running = False
+        self.hass().is_stopping = False
+        self.hass().bus.async_listen_once.return_value = "LISTENER"
+        self.subject.actually_start = Mock()
+
+        # Call the function under test
+        self.subject.start()
+
+        # Did it avoid actually starting?
+        self.subject.actually_start.assert_not_called()
+        # Did it register a listener?
+        self.assertEqual(self.subject._startup_listener, "LISTENER")
+        self.hass().bus.async_listen_once.assert_called_once_with(
+            EVENT_HOMEASSISTANT_STARTED, self.subject.actually_start
+        )
+
+    def test_start_does_nothing_when_ha_stopping(self):
+        # Set up preconditions
+        self.hass().is_running = True
+        self.hass().is_stopping = True
+        self.subject.actually_start = Mock()
+
+        # Call the function under test
+        self.subject.start()
+
+        # Did it avoid actually starting?
+        self.subject.actually_start.assert_not_called()
+        # Did it avoid registering a listener?
+        self.hass().bus.async_listen_once.assert_not_called()
+        self.assertIsNone(self.subject._startup_listener)
+
+    async def test_async_stop(self):
+        # Set up preconditions
+        listener = Mock()
+        self.subject._refresh_task = None
+        self.subject._shutdown_listener = listener
+        self.subject._children = [1, 2, 3]
+
+        # Call the function under test
+        await self.subject.async_stop()
+
+        # Was the shutdown listener cancelled?
+        listener.assert_called_once()
+        self.assertIsNone(self.subject._shutdown_listener)
+        # Were the child entities cleared?
+        self.assertEqual(self.subject._children, [])
+        # Did it wait for the refresh task to finish then clear it?
+        # This doesn't work because AsyncMock only mocks awaitable method calls
+        # but we want an awaitable object
+        # refresh.assert_awaited_once()
+        self.assertIsNone(self.subject._refresh_task)
+
+    async def test_async_stop_when_not_running(self):
+        # Set up preconditions
+        self._refresh_task = None
+        self.subject._shutdown_listener = None
+        self.subject._children = []
+
+        # Call the function under test
+        await self.subject.async_stop()
+
+        # Was the shutdown listener left empty?
+        self.assertIsNone(self.subject._shutdown_listener)
+        # Were the child entities cleared?
+        self.assertEqual(self.subject._children, [])
+        # Was the refresh task left empty?
+        self.assertIsNone(self.subject._refresh_task)
+
+    def test_register_first_entity_ha_running(self):
+        # Set up preconditions
+        self.subject._children = []
+        self.subject._running = False
+        self.subject._startup_listener = None
+        self.subject.start = Mock()
+
+        # Call the function under test
+        self.subject.register_entity("Entity")
+
+        # Was the entity added to the list?
+        self.assertEqual(self.subject._children, ["Entity"])
+
+        # Did we start the loop?
+        self.subject.start.assert_called_once()
+
+    def test_register_subsequent_entity_ha_running(self):
+        # Set up preconditions
+        self.subject._children = ["First"]
+        self.subject._running = True
+        self.subject._startup_listener = None
+        self.subject.start = Mock()
+
+        # Call the function under test
+        self.subject.register_entity("Entity")
+
+        # Was the entity added to the list?
+        self.assertCountEqual(self.subject._children, ["First", "Entity"])
+
+        # Did we avoid restarting the loop?
+        self.subject.start.assert_not_called()
+
+    def test_register_subsequent_entity_ha_starting(self):
+        # Set up preconditions
+        self.subject._children = ["First"]
+        self.subject._running = False
+        self.subject._startup_listener = Mock()
+        self.subject.start = Mock()
+
+        # Call the function under test
+        self.subject.register_entity("Entity")
+
+        # Was the entity added to the list?
+        self.assertCountEqual(self.subject._children, ["First", "Entity"])
+        # Did we avoid restarting the loop?
+        self.subject.start.assert_not_called()
+
+    async def test_unregister_one_of_many_entities(self):
+        # Set up preconditions
+        self.subject._children = ["First", "Second"]
+        self.subject.async_stop = AsyncMock()
+
+        # Call the function under test
+        await self.subject.async_unregister_entity("First")
+
+        # Was the entity removed from the list?
+        self.assertCountEqual(self.subject._children, ["Second"])
+        # Is the loop still running?
+        self.subject.async_stop.assert_not_called()
+
+    async def test_unregister_last_entity(self):
+        # Set up preconditions
+        self.subject._children = ["Last"]
+        self.subject.async_stop = AsyncMock()
+
+        # Call the function under test
+        await self.subject.async_unregister_entity("Last")
+
+        # Was the entity removed from the list?
+        self.assertEqual(self.subject._children, [])
+        # Was the loop stopped?
+        self.subject.async_stop.assert_called_once()