Jelajahi Sumber

config_flow: use product_ids from cloud flow in device matching

Since we can now collect product ids from cloud and local discovery
(currently only used during cloud flow), we can make use of this in
device matching.

If product id matches, return 100% for the match, provided there are
no conflicts in dp types.  Do this even if no dps are available from
the device, to help with matching those devices that don't return much
or any data.
Jason Rumney 1 tahun lalu
induk
melakukan
e42272bdb7

+ 14 - 1
custom_components/tuya_local/config_flow.py

@@ -326,6 +326,16 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
             self.device = await async_test_connection(user_input, self.hass)
             self.device = await async_test_connection(user_input, self.hass)
             if self.device:
             if self.device:
                 self.data = user_input
                 self.data = user_input
+                if self.__cloud_device:
+                    if self.__cloud_device.get("product_id"):
+                        self.device.set_detected_product_id(
+                            self.__cloud_device["product_id"]
+                        )
+                    if self.__cloud_device.get("local_product_id"):
+                        self.device.set_detected_product_id(
+                            self.__cloud_device["local_product_id"]
+                        )
+
                 return await self.async_step_select_type()
                 return await self.async_step_select_type()
             else:
             else:
                 errors["base"] = "connection"
                 errors["base"] = "connection"
@@ -366,7 +376,10 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
 
 
         async for type in self.device.async_possible_types():
         async for type in self.device.async_possible_types():
             types.append(type.config_type)
             types.append(type.config_type)
-            q = type.match_quality(self.device._get_cached_state())
+            q = type.match_quality(
+                self.device._get_cached_state(),
+                self.device._product_ids,
+            )
             if q > best_match:
             if q > best_match:
                 best_match = q
                 best_match = q
                 best_matching_type = type.config_type
                 best_matching_type = type.config_type

+ 6 - 1
custom_components/tuya_local/device.py

@@ -61,6 +61,7 @@ class TuyaLocalDevice(object):
         self._name = name
         self._name = name
         self._children = []
         self._children = []
         self._force_dps = []
         self._force_dps = []
+        self._product_ids = []
         self._running = False
         self._running = False
         self._shutdown_listener = None
         self._shutdown_listener = None
         self._startup_listener = None
         self._startup_listener = None
@@ -368,6 +369,9 @@ class TuyaLocalDevice(object):
         if self._api.parent:
         if self._api.parent:
             self._api.parent.set_socketPersistent(False)
             self._api.parent.set_socketPersistent(False)
 
 
+    def set_detected_product_id(self, product_id):
+        self._product_ids.append(product_id)
+
     async def async_possible_types(self):
     async def async_possible_types(self):
         cached_state = self._get_cached_state()
         cached_state = self._get_cached_state()
         if len(cached_state) <= 1:
         if len(cached_state) <= 1:
@@ -395,6 +399,7 @@ class TuyaLocalDevice(object):
         for matched in await self._hass.async_add_executor_job(
         for matched in await self._hass.async_add_executor_job(
             possible_matches,
             possible_matches,
             cached_state,
             cached_state,
+            self._product_ids,
         ):
         ):
             await asyncio.sleep(0)
             await asyncio.sleep(0)
             yield matched
             yield matched
@@ -404,7 +409,7 @@ class TuyaLocalDevice(object):
         best_quality = 0
         best_quality = 0
         cached_state = self._get_cached_state()
         cached_state = self._get_cached_state()
         async for config in self.async_possible_types():
         async for config in self.async_possible_types():
-            quality = config.match_quality(cached_state)
+            quality = config.match_quality(cached_state, self._product_ids)
             _LOGGER.info(
             _LOGGER.info(
                 "%s considering %s with quality %s",
                 "%s considering %s with quality %s",
                 self.name,
                 self.name,

+ 30 - 9
custom_components/tuya_local/helpers/device_config.py

@@ -141,7 +141,15 @@ class TuyaDeviceConfig:
         for e in self.secondary_entities():
         for e in self.secondary_entities():
             yield e
             yield e
 
 
-    def matches(self, dps):
+    def matches(self, dps, product_ids):
+        """Determine whether this config matches the provided dps map or
+        product ids."""
+        product_match = False
+        if product_ids:
+            for p in self._config.get("products", []):
+                if p.get("id", "MISSING_ID!?!") in product_ids:
+                    product_match = True
+
         required_dps = self._get_required_dps()
         required_dps = self._get_required_dps()
 
 
         missing_dps = [dp for dp in required_dps if dp.id not in dps.keys()]
         missing_dps = [dp for dp in required_dps if dp.id not in dps.keys()]
@@ -163,8 +171,14 @@ class TuyaDeviceConfig:
                 self.name,
                 self.name,
                 [{dp.id: dp.type.__name__} for dp in incorrect_type_dps],
                 [{dp.id: dp.type.__name__} for dp in incorrect_type_dps],
             )
             )
+            if product_match:
+                _LOGGER.warning(
+                    "Product matches %s but dps mismatched",
+                    self.name,
+                )
+            return False
 
 
-        return len(missing_dps) == 0 and len(incorrect_type_dps) == 0
+        return product_match or len(missing_dps) == 0
 
 
     def _get_all_dps(self):
     def _get_all_dps(self):
         all_dps_list = []
         all_dps_list = []
@@ -199,21 +213,27 @@ class TuyaDeviceConfig:
                 keys.remove(d.id)
                 keys.remove(d.id)
         return True
         return True
 
 
-    def match_quality(self, dps):
-        """Determine the match quality for the provided dps map."""
+    def match_quality(self, dps, product_ids=None):
+        """Determine the match quality for the provided dps map and product ids."""
+        product_match = 0
+        if product_ids:
+            for p in self._config.get("products", []):
+                if p.get("id", "MISSING_ID!?!") in product_ids:
+                    product_match = 100
+
         keys = list(dps.keys())
         keys = list(dps.keys())
         matched = []
         matched = []
         if "updated_at" in keys:
         if "updated_at" in keys:
             keys.remove("updated_at")
             keys.remove("updated_at")
         total = len(keys)
         total = len(keys)
         if total < 1:
         if total < 1:
-            return 0
+            return product_match
 
 
         for e in self.all_entities():
         for e in self.all_entities():
             if not self._entity_match_analyse(e, keys, matched, dps):
             if not self._entity_match_analyse(e, keys, matched, dps):
                 return 0
                 return 0
 
 
-        return round((total - len(keys)) * 100 / total)
+        return product_match or round((total - len(keys)) * 100 / total)
 
 
 
 
 class TuyaEntityConfig:
 class TuyaEntityConfig:
@@ -1010,12 +1030,13 @@ def available_configs():
                 yield basename
                 yield basename
 
 
 
 
-def possible_matches(dps):
-    """Return possible matching configs for a given set of dps values."""
+def possible_matches(dps, product_ids=None):
+    """Return possible matching configs for a given set of
+    dps values and product_ids."""
     for cfg in available_configs():
     for cfg in available_configs():
         parsed = TuyaDeviceConfig(cfg)
         parsed = TuyaDeviceConfig(cfg)
         try:
         try:
-            if parsed.matches(dps):
+            if parsed.matches(dps, product_ids):
                 yield parsed
                 yield parsed
         except TypeError:
         except TypeError:
             _LOGGER.error("Parse error in %s", cfg)
             _LOGGER.error("Parse error in %s", cfg)

+ 10 - 0
tests/test_device_config.py

@@ -768,3 +768,13 @@ class TestDeviceConfig(IsolatedAsyncioTestCase):
         mock_config = {"id": "1", "name": "test", "type": "string"}
         mock_config = {"id": "1", "name": "test", "type": "string"}
         cfg = TuyaDpsConfig(mock_entity, mock_config)
         cfg = TuyaDpsConfig(mock_entity, mock_config)
         self.assertIsNone(cfg.default)
         self.assertIsNone(cfg.default)
+
+    def test_matching_with_product_id(self):
+        """Test that matching with product id works"""
+        cfg = get_config("smartplugv1")
+        self.assertTrue(cfg.matches({}, ["37mnhia3pojleqfh"]))
+
+    def test_matched_product_id_with_conflict_rejected(self):
+        """Test that matching with product id fails when there is a conflict"""
+        cfg = get_config("smartplugv1")
+        self.assertFalse(cfg.matches({"1": "wrong_type"}, ["37mnhia3pojleqfh"]))