Răsfoiți Sursa

fix(device): delay api creation until async functions

Cannot call async tinytuya device creation until we are in async code.
Jason Rumney 5 luni în urmă
părinte
comite
6f67fca1ed
1 a modificat fișierele cu 69 adăugiri și 40 ștergeri
  1. 69 40
      custom_components/tuya_local/device.py

+ 69 - 40
custom_components/tuya_local/device.py

@@ -74,40 +74,12 @@ class TuyaLocalDevice(object):
         self._api_protocol_working = False
         self._api_working_protocol_failures = 0
         self.dev_cid = dev_cid
-        try:
-            if dev_cid:
-                if hass.data[DOMAIN].get(dev_id) and name != "Test":
-                    parent = hass.data[DOMAIN][dev_id]["tuyadevice"]
-                else:
-                    parent = tinytuya.Device(dev_id, address, local_key)
-                    if name != "Test":
-                        hass.data[DOMAIN][dev_id] = {"tuyadevice": parent}
-                self._api = await tinytuya.DeviceAsync.create(
-                    dev_cid,
-                    cid=dev_cid,
-                    parent=parent,
-                )
-            else:
-                if hass.data[DOMAIN].get(dev_id) and name != "Test":
-                    self._api = hass.data[DOMAIN][dev_id]["tuyadevice"]
-                else:
-                    self._api = tinytuya.Device(dev_id, address, local_key)
-                    if name != "Test":
-                        hass.data[DOMAIN][dev_id] = {"tuyadevice": self._api}
-        except Exception as e:
-            _LOGGER.error(
-                "%s: %s while initialising device %s",
-                type(e).__name__,
-                e,
-                dev_id,
-            )
-            raise e
-
-        # we handle retries at a higher level so we can rotate protocol version
-        self._api.set_socketRetryLimit(1)
-        if self._api.parent:
-            # Retries cause problems for other children of the parent device
-            self._api.parent.set_socketRetryLimit(1)
+        self.address = address
+        self.dev_id = dev_id
+        self.local_key = local_key
+        self._api = None
+        if hass.data[DOMAIN].get(dev_id) and name != "Test":
+            self._api = hass.data[DOMAIN][dev_id]["tuyadevice"]
 
         self._refresh_task = None
         self._protocol_configured = protocol_version
@@ -141,7 +113,7 @@ class TuyaLocalDevice(object):
     @property
     def unique_id(self):
         """Return the unique id for this device (the dev_id or dev_cid)."""
-        return self.dev_cid or self._api.id
+        return self.dev_cid or self.dev_id
 
     @property
     def device_info(self):
@@ -193,7 +165,9 @@ class TuyaLocalDevice(object):
             await self._refresh_task
         _LOGGER.debug("Monitor loop for %s stopped", self.name)
         self._refresh_task = None
-        await self._api.close()
+        if self._api:
+            await self._api.close()
+            self._api = None
 
     def register_entity(self, entity):
         # If this is the first child entity to register, and HA is still
@@ -267,14 +241,53 @@ class TuyaLocalDevice(object):
 
     def pause(self):
         self._temporary_poll = True
-        self._api.set_socketPersistent(False)
-        if self._api.parent:
-            self._api.parent.set_socketPersistent(False)
+        if self._api:
+            self._api.set_socketPersistent(False)
+            if self._api.parent:
+                self._api.parent.set_socketPersistent(False)
 
     def resume(self):
         self._temporary_poll = False
 
-    async def async_receive(self):
+    async def async_ensure_connection(self):
+        """Ensure the device is connected and has returned state."""
+        if self._api is None:
+            try:
+                if self.dev_cid:
+                    if self._hass.data[DOMAIN].get(self.dev_id) and self.name != "Test":
+                        parent = self._hass.data[DOMAIN][self.dev_id]["tuyadevice"]
+                    else:
+                        parent = None
+                        if self.name != "Test":
+                            self._hass.data[DOMAIN][self.dev_id] = {
+                                "tuyadevice": parent
+                            }
+                    self._api = await tinytuya.DeviceAsync.create(
+                        self.dev_cid,
+                        cid=self.dev_cid,
+                        parent=parent,
+                    )
+                else:
+                    if self._hass.data[DOMAIN].get(self.dev_id) and self.name != "Test":
+                        self._api = self._hass.data[DOMAIN][self.dev_id]["tuyadevice"]
+                    else:
+                        self._api = await tinytuya.DeviceAsync.create(
+                            self.dev_id, self.address, self.local_key
+                        )
+                        if self.name != "Test":
+                            self._hass.data[DOMAIN][self.dev_id] = {
+                                "tuyadevice": self._api
+                            }
+            except Exception as e:
+                _LOGGER.error(
+                    "%s: %s while initialising device %s",
+                    type(e).__name__,
+                    e,
+                    self.dev_id,
+                )
+                raise e
+
+      async def async_receive(self):
         """Receive messages from a persistent connection asynchronously."""
         # If we didn't yet get any state from the device, we may need to
         # negotiate the protocol before making the connection persistent
@@ -283,6 +296,14 @@ class TuyaLocalDevice(object):
         # all dps updated
         dps_updated = False
 
+        await self.async_ensure_connection()
+
+        # we handle retries at a higher level so we can rotate protocol version
+        self._api.set_socketRetryLimit(1)
+        if self._api.parent:
+            # Retries cause problems for other children of the parent device
+            self._api.parent.set_socketRetryLimit(1)
+
         self._api.set_socketPersistent(persist)
         if self._api.parent:
             self._api.parent.set_socketPersistent(persist)
@@ -386,6 +407,7 @@ class TuyaLocalDevice(object):
         self._product_ids.append(product_id)
 
     async def async_possible_types(self):
+        await self.async_ensure_connection()
         cached_state = self._get_cached_state()
         if len(cached_state) <= 1:
             # in case of device22 devices, we need to poll them with a dp
@@ -473,6 +495,7 @@ class TuyaLocalDevice(object):
         self._last_connection = 0
 
     async def _refresh_cached_state(self):
+        await self.async_ensure_connection()
         new_state = await self._api.status()
         if new_state and "Err" not in new_state:
             self._cached_state = self._cached_state | new_state.get("dps", {})
@@ -569,6 +592,7 @@ class TuyaLocalDevice(object):
         )
 
     async def _set_values(self, properties):
+        await self.async_ensure_connection()
         try:
             self._lock.acquire()
             await self._api.set_multiple_values(properties, nowait=True)
@@ -583,6 +607,7 @@ class TuyaLocalDevice(object):
             self._lock.release()
 
     async def _retry_on_failed_connection(self, func, error_message):
+        await self.async_ensure_connection()
         if self._api_protocol_version_index is None:
             self._rotate_api_protocol_version()
         auto = (self._protocol_configured == "auto") and (
@@ -683,6 +708,10 @@ class TuyaLocalDevice(object):
             self.name,
             new_version,
         )
+        # If we don't have a connection, don't set the version yet
+        if not self._api:
+            return
+
         # Only enable tinytuya's auto-detect when using 3.22
         if new_version == 3.22:
             new_version = 3.3