4
0
Эх сурвалжийг харах

Refactor cloud functionality into its own file.

In future we might like to use cloud to add extra functionality such
as cloud based camera feeds and vacuum maps. This will be easier if
cloud functions are not bundled in with config_flow.

- add QueryThingsDataModel, so it can be logged along with product id
  to make reporting new devices with the detail we need easier.
Jason Rumney 1 жил өмнө
parent
commit
c95d131e60

+ 241 - 0
custom_components/tuya_local/cloud.py

@@ -0,0 +1,241 @@
+import logging
+from typing import Any
+
+from homeassistant.core import HomeAssistant
+
+from tuya_sharing import (
+    CustomerDevice,
+    LoginControl,
+    Manager,
+    SharingDeviceListener,
+    SharingTokenListener,
+)
+
+from .const import (
+    CONF_ENDPOINT,
+    CONF_TERMINAL_ID,
+    DOMAIN,
+    TUYA_CLIENT_ID,
+    TUYA_RESPONSE_QR_CODE,
+    TUYA_RESPONSE_RESULT,
+    TUYA_RESPONSE_SUCCESS,
+    TUYA_SCHEMA,
+)
+
+_LOGGER = logging.getLogger(__name__)
+
+HUB_CATEGORIES = [
+    "wgsxj",  # Gateway camera
+    "lyqwg",  # Router
+    "bywg",  # IoT edge gateway
+    "zigbee",  # Gateway
+    "wg2",  # Gateway
+    "dgnzk",  # Multi-function controller
+    "videohub",  # Videohub
+    "xnwg",  # Virtual gateway
+    "qtyycp",  # Voice gateway composite solution
+    "alexa_yywg",  # Gateway with Alexa
+    "gywg",  # Industrial gateway
+    "cnwg",  # Energy gateway
+    "wnykq",  # Smart IR
+]
+
+
+class Cloud:
+    """Optional Tuya cloud interface for getting device information."""
+
+    def __init__(self, hass: HomeAssistant):
+        self.__login_control = LoginControl()
+        self.__authentication = {}
+        self.__user_code = None
+        self.__qr_code = None
+        self.__hass = hass
+        self.__error_code = None
+        self.__error_msg = None
+
+    async def async_get_qr_code(self, user_code: str | None = None) -> bool:
+        """Get QR code from Tuya server for user code authentication."""
+        if not user_code:
+            user_code = self.__user_code
+            if not user_code:
+                _LOGGER.error("Cannot get QR code without a user code")
+                return False, {TUYA_RESPONSE_MSG: "QR code requires a user code"}
+
+        response = await self.__hass.async_add_executor_job(
+            self.__login_control.qr_code,
+            TUYA_CLIENT_ID,
+            TUYA_SCHEMA,
+            user_code,
+        )
+        if success := response.get(TUYA_RESPONSE_SUCCESS, False):
+            self.__user_code = user_code
+            self.__qr_code = response[TUYA_RESPONSE_RESULT][TUYA_RESPONSE_QR_CODE]
+            return self.__qr_code
+
+        self.__error_code = response.get(TUYA_RESPONSE_CODE, {})
+        self.__error_msg = response.get(TUYA_RESPONSE_MSG, "Unknown error")
+
+        return False
+
+    async def async_login(self) -> bool:
+        """Login to the Tuya cloud."""
+        if not self.__user_code or not self.__qr_code:
+            _LOGGER.warn("Login attempted without successful QR scan")
+            return False, {}
+
+        success, info = await self.__hass.async_add_executor_job(
+            self.__login_control.login_result,
+            self.__qr_code,
+            TUYA_CLIENT_ID,
+            self.__user_code,
+        )
+        if success:
+            self.__authentication = {
+                "user_code": self.__user_code,
+                "terminal_id": info[CONF_TERMINAL_ID],
+                "endpoint": info[CONF_ENDPOINT],
+                "token_info": {
+                    "t": info["t"],
+                    "uid": info["uid"],
+                    "expire_time": info["expire_time"],
+                    "access_token": info["access_token"],
+                    "refresh_token": info["refresh_token"],
+                },
+            }
+        else:
+            self.__error_code = response.get(TUYA_RESPONSE_CODE, {})
+            self.__error_msg = response.get(TUYA_RESPONSE_MSG, "Unknown error")
+
+        return success
+
+    async def async_get_devices(self) -> dict[str, Any]:
+        """Get all devices associated with the account."""
+        token_listener = TokenListener(self.__hass)
+        manager = Manager(
+            TUYA_CLIENT_ID,
+            self.__authentication["user_code"],
+            self.__authentication["terminal_id"],
+            self.__authentication["endpoint"],
+            self.__authentication["token_info"],
+            token_listener,
+        )
+
+        listener = DeviceListener(self.__hass, manager)
+        manager.add_device_listener(listener)
+
+        # Get all devices from Tuya cloud
+        await self.__hass.async_add_executor_job(manager.update_device_cache)
+
+        # Register known device IDs
+        cloud_devices = {}
+        domain_data = self.__hass.data.get(DOMAIN)
+        for device in manager.device_map.values():
+            cloud_device = {
+                "category": device.category,
+                "id": device.id,
+                "ip": device.ip,
+                CONF_LOCAL_KEY: device.local_key
+                if hassattr(device, CONF_LOCAL_KEY)
+                else "",
+                "name": device.name,
+                "node_id": device.node_id if hasattr(device, "node_id") else "",
+                "online": device.online,
+                "product_id": device.product_id,
+                "product_name": device.product_name,
+                "uid": device.uid,
+                "uuid": device.uuid,
+                "support_local": device.support_local,
+                CONF_DEVICE_CID: None,
+                "version": None,
+                "is_hub": device.caetgory in HUB_CATEGORIES or device.local_key == "",
+            }
+            _LOGGER.debug("Found device: {cloud_device}")
+
+            existing_id = domain_data.get(cloud_device["id"]) if domain_data else None
+            existing_uuid = (
+                domain_data.get(cloud_device["uuid"]) if domain_data else None
+            )
+            existing = existing_id or existing_uuid
+            cloud_device["exists"] = existing and existing.get("device")
+            cloud_devices[cloud_device["id"]] = cloud_device
+
+        return cloud_devices
+
+    async def async_get_datamodel(self, device_id) -> dict[str, Any] | None:
+        """Get the data model for the specified device (QueryThingsDataModel)."""
+        token_listener = TokenListener(self.__hass)
+        manager = Manager(
+            TUYA_CLIENT_ID,
+            self.__authentication["user_code"],
+            self.__authentication["terminal_id"],
+            self.__authentication["endpoint"],
+            self.__authentication["token_info"],
+            token_listener,
+        )
+        response = await self.__hass.async_add_executor_job(
+            manager.customer_api.get,
+            f"/v2.0/cloud/thing/{device_id}/model",
+        )
+        return response
+
+    @property
+    def is_authenticated(self) -> bool:
+        """Is the cloud account authenticated?"""
+        return True if self.__authentication else False
+
+    @property
+    def last_error(self) -> dict[str, Any] | None:
+        """The last cloud error code and message, if any."""
+        if self.__error_code is not None:
+            return {
+                TUYA_RESPONSE_MSG: self.__error_msg,
+                TUYA_RESPONSE_CODE: self.__error_code,
+            }
+
+
+class DeviceListener(SharingDeviceListener):
+    """Device update listener."""
+
+    def __init__(
+        self,
+        hass: HomeAssistant,
+        manager: Manager,
+    ):
+        self.__hass = hass
+        self._manager = manager
+
+    def update_device(self, device: CustomerDevice) -> None:
+        """Device status has updated."""
+        _LOGGER.debug(
+            "Received update for device %s: %s",
+            device.id,
+            self._manager.device_map[device.id].status,
+        )
+
+    def add_device(self, device: CustomerDevice) -> None:
+        """A new device has been added."""
+        _LOGGER.device(
+            "Received add device %s: %s",
+            device.id,
+            self._manager.device_map[device.id].status,
+        )
+
+    def remove_device(self, device_id: str) -> None:
+        """A device has been removed."""
+        _LOGGER.debug(
+            "Received remove device %s: %s",
+            device_id,
+            self._manager.device_map[device_id].status,
+        )
+
+
+class TokenListener(SharingTokenListener):
+    """Listener for upstream token updates.
+    This is only needed to get some debug output when tokens are refreshed."""
+
+    def __init__(self, hass: HomeAssistant):
+        self.__hass = hass
+
+    def update_token(self, token_info: dict[str, Any]) -> None:
+        """Update the token information."""
+        _LOGGER.debug("Token updated")

+ 31 - 215
custom_components/tuya_local/config_flow.py

@@ -18,34 +18,19 @@ from homeassistant.helpers.selector import (
     SelectSelectorConfig,
     SelectSelectorMode,
 )
-from tuya_sharing import (
-    CustomerDevice,
-    LoginControl,
-    Manager,
-    SharingDeviceListener,
-    SharingTokenListener,
-)
 
 from . import DOMAIN
+from .cloud import Cloud
 from .const import (
     API_PROTOCOL_VERSIONS,
     CONF_DEVICE_CID,
     CONF_DEVICE_ID,
-    CONF_ENDPOINT,
     CONF_LOCAL_KEY,
     CONF_POLL_ONLY,
     CONF_PROTOCOL_VERSION,
-    CONF_TERMINAL_ID,
     CONF_TYPE,
     CONF_USER_CODE,
     DATA_STORE,
-    TUYA_CLIENT_ID,
-    TUYA_RESPONSE_CODE,
-    TUYA_RESPONSE_MSG,
-    TUYA_RESPONSE_QR_CODE,
-    TUYA_RESPONSE_RESULT,
-    TUYA_RESPONSE_SUCCESS,
-    TUYA_SCHEMA,
 )
 from .device import TuyaLocalDevice
 from .helpers.config import get_device_id
@@ -54,22 +39,6 @@ from .helpers.log import log_json
 
 _LOGGER = logging.getLogger(__name__)
 
-HUB_CATEGORIES = [
-    "wgsxj",  # Gateway camera
-    "lyqwg",  # Router
-    "bywg",  # IoT edge gateway
-    "zigbee",  # Gateway
-    "wg2",  # Gateway
-    "dgnzk",  # Multi-function controller
-    "videohub",  # Videohub
-    "xnwg",  # Virtual gateway
-    "qtyycp",  # Voice gateway composite solution
-    "alexa_yywg",  # Gateway with Alexa
-    "gywg",  # Industrial gateway
-    "cnwg",  # Energy gateway
-    "wnykq",  # Smart IR
-]
-
 
 class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
     VERSION = 13
@@ -78,17 +47,13 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
     device = None
     data = {}
 
-    __user_code: str
-    __qr_code: str
-    __authentication: dict
-    __cloud_devices: dict
-    __cloud_device: dict
+    __qr_code: str | None = None
+    __cloud_devices: dict[str, Any] = {}
+    __cloud_device: dict[str, Any] | None = None
 
     def __init__(self) -> None:
         """Initialize the config flow."""
-        self.__login_control = LoginControl()
-        self.__cloud_devices = {}
-        self.__cloud_device = None
+        self.cloud = Cloud(self.hass)
 
     async def async_step_user(self, user_input=None):
         errors = {}
@@ -97,23 +62,20 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
             self.hass.data[DOMAIN] = {}
         if self.hass.data[DOMAIN].get(DATA_STORE) is None:
             self.hass.data[DOMAIN][DATA_STORE] = {}
-        self.__authentication = self.hass.data[DOMAIN][DATA_STORE].get(
-            "authentication", None
-        )
 
         if user_input is not None:
             if user_input["setup_mode"] == "cloud":
                 try:
-                    if self.__authentication is not None:
-                        self.__cloud_devices = await self.load_device_info()
-                        return await self.async_step_choose_device(None)
+                    if self.cloud.is_authenticated:
+                        self.__cloud_devices = await self.cloud.async_get_devices()
+                        return await self.async_step_choose_device()
                 except Exception as e:
                     # Re-authentication is needed.
                     _LOGGER.warning("Connection test failed with %s %s", type(e), e)
                     _LOGGER.warning("Re-authentication is required.")
-                return await self.async_step_cloud(None)
+                return await self.async_step_cloud()
             if user_input["setup_mode"] == "manual":
-                return await self.async_step_local(None)
+                return await self.async_step_local()
 
         # Build form
         fields: OrderedDict[vol.Marker, Any] = OrderedDict()
@@ -140,17 +102,14 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
         placeholders = {}
 
         if user_input is not None:
-            success, response = await self.__async_get_qr_code(
-                user_input[CONF_USER_CODE]
-            )
-            if success:
-                return await self.async_step_scan(None)
+            response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
+            if response:
+                self.__qr_code = response
+                return await self.async_step_scan()
 
             errors["base"] = "login_error"
-            placeholders = {
-                TUYA_RESPONSE_MSG: response.get(TUYA_RESPONSE_MSG, "Unknown error"),
-                TUYA_RESPONSE_CODE: response.get(TUYA_RESPONSE_CODE, "0"),
-            }
+            placeholders = self.cloud.last_error
+
         else:
             user_input = {}
 
@@ -187,18 +146,17 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                 ),
             )
 
-        ret, info = await self.hass.async_add_executor_job(
-            self.__login_control.login_result,
-            self.__qr_code,
-            TUYA_CLIENT_ID,
-            self.__user_code,
-        )
-        if not ret:
+        if not await self.cloud.async_login():
             # Try to get a new QR code on failure
-            await self.__async_get_qr_code(self.__user_code)
+            response = await self.cloud.async_get_qr_code()
+            errors["base"] = "login_error"
+            placeholders = self.cloud.last_error
+            if response:
+                self.__qr_code = response
+
             return self.async_show_form(
                 step_id="scan",
-                errors={"base": "login_error"},
+                errors=errors,
                 data_schema=vol.Schema(
                     {
                         vol.Optional("QR"): QrCodeSelector(
@@ -210,86 +168,12 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                         )
                     }
                 ),
-                description_placeholders={
-                    TUYA_RESPONSE_MSG: info.get(TUYA_RESPONSE_MSG, "Unknown error"),
-                    TUYA_RESPONSE_CODE: info.get(TUYA_RESPONSE_CODE, 0),
-                },
-            )
-
-        # Now that we have successfully logged in we can query for devices for the account.
-        self.__authentication = {
-            "user_code": info[CONF_TERMINAL_ID],
-            "terminal_id": info[CONF_TERMINAL_ID],
-            "endpoint": info[CONF_ENDPOINT],
-            "token_info": {
-                "t": info["t"],
-                "uid": info["uid"],
-                "expire_time": info["expire_time"],
-                "access_token": info["access_token"],
-                "refresh_token": info["refresh_token"],
-            },
-        }
-        self.hass.data[DOMAIN][DATA_STORE]["authentication"] = self.__authentication
-        _LOGGER.debug(f"domain_data is {self.hass.data[DOMAIN]}")
-
-        self.__cloud_devices = await self.load_device_info()
-
-        return await self.async_step_choose_device(None)
-
-    async def load_device_info(self) -> dict:
-        token_listener = TokenListener(self.hass)
-        manager = Manager(
-            TUYA_CLIENT_ID,
-            self.__authentication["user_code"],
-            self.__authentication["terminal_id"],
-            self.__authentication["endpoint"],
-            self.__authentication["token_info"],
-            token_listener,
-        )
-
-        listener = DeviceListener(self.hass, manager)
-        manager.add_device_listener(listener)
-
-        # Get all devices from Tuya
-        await self.hass.async_add_executor_job(manager.update_device_cache)
-
-        # Register known device IDs
-        cloud_devices = {}
-        domain_data = self.hass.data.get(DOMAIN)
-        for device in manager.device_map.values():
-            cloud_device = {
-                # TODO - Use constants throughout
-                "category": device.category,
-                "id": device.id,
-                "ip": device.ip,  # This will be the WAN IP address so not usable.
-                CONF_LOCAL_KEY: device.local_key
-                if hasattr(device, CONF_LOCAL_KEY)
-                else "",
-                "name": device.name,
-                "node_id": device.node_id if hasattr(device, "node_id") else "",
-                "online": device.online,
-                "product_id": device.product_id,
-                "product_name": device.product_name,
-                "uid": device.uid,
-                "uuid": device.uuid,
-                "support_local": device.support_local,  # What does this mean?
-                CONF_DEVICE_CID: None,
-                "version": None,
-            }
-            _LOGGER.debug(f"Found device: {cloud_device}")
-
-            existing_id = domain_data.get(cloud_device["id"]) if domain_data else None
-            existing_uuid = (
-                domain_data.get(cloud_device["uuid"]) if domain_data else None
+                description_placeholders=placeholders,
             )
-            existing = existing_id or existing_uuid
-            if existing and existing.get("device"):
-                cloud_device["exists"] = True
 
-            _LOGGER.debug(f"Adding device: {cloud_device['id']}")
-            cloud_devices[cloud_device["id"]] = cloud_device
+        self.__cloud_devices = await self.cloud.async_get_devices()
 
-        return cloud_devices
+        return await self.async_step_choose_device()
 
     async def async_step_choose_device(self, user_input=None):
         errors = {}
@@ -301,7 +185,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                 if user_input["hub_id"] == "None":
                     device_choice["ip"] = ""
                     self.__cloud_device = device_choice
-                    return await self.async_step_search(None)
+                    return await self.async_step_search()
                 else:
                     # Show error if user selected a hub.
                     errors["base"] = "does_not_need_hub"
@@ -315,7 +199,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                     hub_choice[CONF_DEVICE_CID] = device_choice["uuid"]
                     hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
                     self.__cloud_device = hub_choice
-                    return await self.async_step_search(None)
+                    return await self.async_step_search()
                 else:
                     # Show error if user did not select a hub.
                     errors["base"] = "needs_hub"
@@ -354,10 +238,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
         hub_list.append(SelectOptionDict(value="None", label="None"))
         for key in self.__cloud_devices.keys():
             hub_entry = self.__cloud_devices[key]
-            if (
-                hub_entry[CONF_LOCAL_KEY] == ""
-                or hub_entry["category"] in HUB_CATEGORIES
-            ):
+            if hub_entry["is_hub"]:
                 hub_list.append(
                     SelectOptionDict(
                         value=key,
@@ -406,7 +287,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                 self.__cloud_device["version"] = local_device["version"]
             else:
                 _LOGGER.warning(f"Could not find device: {self.__cloud_device['id']}")
-            return await self.async_step_local(None)
+            return await self.async_step_local()
 
         return self.async_show_form(
             step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
@@ -530,19 +411,6 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
     def async_get_options_flow(config_entry):
         return OptionsFlowHandler(config_entry)
 
-    async def __async_get_qr_code(self, user_code: str) -> tuple[bool, dict[str, Any]]:
-        """Get the QR code."""
-        response = await self.hass.async_add_executor_job(
-            self.__login_control.qr_code,
-            TUYA_CLIENT_ID,
-            TUYA_SCHEMA,
-            user_code,
-        )
-        if success := response.get(TUYA_RESPONSE_SUCCESS, False):
-            self.__user_code = user_code
-            self.__qr_code = response[TUYA_RESPONSE_RESULT][TUYA_RESPONSE_QR_CODE]
-        return success, response
-
 
 class OptionsFlowHandler(config_entries.OptionsFlow):
     def __init__(self, config_entry):
@@ -643,55 +511,3 @@ async def async_test_connection(config: dict, hass: HomeAssistant):
 
 def scan_for_device(id):
     return tinytuya.find_device(dev_id=id)
-
-
-class DeviceListener(SharingDeviceListener):
-    """Device Update Listener."""
-
-    def __init__(
-        self,
-        hass: HomeAssistant,
-        manager: Manager,
-    ) -> None:
-        """Init DeviceListener."""
-        self.hass = hass
-        self.manager = manager
-
-    def update_device(self, device: CustomerDevice) -> None:
-        """Update device status."""
-        _LOGGER.debug(
-            "Received update for device %s: %s",
-            device.id,
-            self.manager.device_map[device.id].status,
-        )
-
-    def add_device(self, device: CustomerDevice) -> None:
-        """Add device added listener."""
-        _LOGGER.debug(
-            "Received add device %s: %s",
-            device.id,
-            self.manager.device_map[device.id].status,
-        )
-
-    def remove_device(self, device_id: str) -> None:
-        """Add device removed listener."""
-        _LOGGER.debug(
-            "Received remove device %s: %s",
-            device_id,
-            self.manager.device_map[device_id].status,
-        )
-
-
-class TokenListener(SharingTokenListener):
-    """Token listener for upstream token updates."""
-
-    def __init__(
-        self,
-        hass: HomeAssistant,
-    ) -> None:
-        """Init TokenListener."""
-        self.hass = hass
-
-    def update_token(self, token_info: dict[str, Any]) -> None:
-        """Update token info in config entry."""
-        _LOGGER.debug("update_token")