Просмотр исходного кода

config_flow: use product_ids from cloud flow in device matching

Since we can now collect product ids from cloud and local discovery
(currently only used during cloud flow), we can make use of this in
device matching.

If product id matches, return 100% for the match, provided there are
no conflicts in dp types.  Do this even if no dps are available from
the device, to help with matching those devices that don't return much
or any data.
Jason Rumney 1 год назад
Родитель
Сommit
e42272bdb7

+ 14 - 1
custom_components/tuya_local/config_flow.py

@@ -326,6 +326,16 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
             self.device = await async_test_connection(user_input, self.hass)
             if self.device:
                 self.data = user_input
+                if self.__cloud_device:
+                    if self.__cloud_device.get("product_id"):
+                        self.device.set_detected_product_id(
+                            self.__cloud_device["product_id"]
+                        )
+                    if self.__cloud_device.get("local_product_id"):
+                        self.device.set_detected_product_id(
+                            self.__cloud_device["local_product_id"]
+                        )
+
                 return await self.async_step_select_type()
             else:
                 errors["base"] = "connection"
@@ -366,7 +376,10 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
 
         async for type in self.device.async_possible_types():
             types.append(type.config_type)
-            q = type.match_quality(self.device._get_cached_state())
+            q = type.match_quality(
+                self.device._get_cached_state(),
+                self.device._product_ids,
+            )
             if q > best_match:
                 best_match = q
                 best_matching_type = type.config_type

+ 6 - 1
custom_components/tuya_local/device.py

@@ -61,6 +61,7 @@ class TuyaLocalDevice(object):
         self._name = name
         self._children = []
         self._force_dps = []
+        self._product_ids = []
         self._running = False
         self._shutdown_listener = None
         self._startup_listener = None
@@ -368,6 +369,9 @@ class TuyaLocalDevice(object):
         if self._api.parent:
             self._api.parent.set_socketPersistent(False)
 
+    def set_detected_product_id(self, product_id):
+        self._product_ids.append(product_id)
+
     async def async_possible_types(self):
         cached_state = self._get_cached_state()
         if len(cached_state) <= 1:
@@ -395,6 +399,7 @@ class TuyaLocalDevice(object):
         for matched in await self._hass.async_add_executor_job(
             possible_matches,
             cached_state,
+            self._product_ids,
         ):
             await asyncio.sleep(0)
             yield matched
@@ -404,7 +409,7 @@ class TuyaLocalDevice(object):
         best_quality = 0
         cached_state = self._get_cached_state()
         async for config in self.async_possible_types():
-            quality = config.match_quality(cached_state)
+            quality = config.match_quality(cached_state, self._product_ids)
             _LOGGER.info(
                 "%s considering %s with quality %s",
                 self.name,

+ 30 - 9
custom_components/tuya_local/helpers/device_config.py

@@ -141,7 +141,15 @@ class TuyaDeviceConfig:
         for e in self.secondary_entities():
             yield e
 
-    def matches(self, dps):
+    def matches(self, dps, product_ids):
+        """Determine whether this config matches the provided dps map or
+        product ids."""
+        product_match = False
+        if product_ids:
+            for p in self._config.get("products", []):
+                if p.get("id", "MISSING_ID!?!") in product_ids:
+                    product_match = True
+
         required_dps = self._get_required_dps()
 
         missing_dps = [dp for dp in required_dps if dp.id not in dps.keys()]
@@ -163,8 +171,14 @@ class TuyaDeviceConfig:
                 self.name,
                 [{dp.id: dp.type.__name__} for dp in incorrect_type_dps],
             )
+            if product_match:
+                _LOGGER.warning(
+                    "Product matches %s but dps mismatched",
+                    self.name,
+                )
+            return False
 
-        return len(missing_dps) == 0 and len(incorrect_type_dps) == 0
+        return product_match or len(missing_dps) == 0
 
     def _get_all_dps(self):
         all_dps_list = []
@@ -199,21 +213,27 @@ class TuyaDeviceConfig:
                 keys.remove(d.id)
         return True
 
-    def match_quality(self, dps):
-        """Determine the match quality for the provided dps map."""
+    def match_quality(self, dps, product_ids=None):
+        """Determine the match quality for the provided dps map and product ids."""
+        product_match = 0
+        if product_ids:
+            for p in self._config.get("products", []):
+                if p.get("id", "MISSING_ID!?!") in product_ids:
+                    product_match = 100
+
         keys = list(dps.keys())
         matched = []
         if "updated_at" in keys:
             keys.remove("updated_at")
         total = len(keys)
         if total < 1:
-            return 0
+            return product_match
 
         for e in self.all_entities():
             if not self._entity_match_analyse(e, keys, matched, dps):
                 return 0
 
-        return round((total - len(keys)) * 100 / total)
+        return product_match or round((total - len(keys)) * 100 / total)
 
 
 class TuyaEntityConfig:
@@ -1010,12 +1030,13 @@ def available_configs():
                 yield basename
 
 
-def possible_matches(dps):
-    """Return possible matching configs for a given set of dps values."""
+def possible_matches(dps, product_ids=None):
+    """Return possible matching configs for a given set of
+    dps values and product_ids."""
     for cfg in available_configs():
         parsed = TuyaDeviceConfig(cfg)
         try:
-            if parsed.matches(dps):
+            if parsed.matches(dps, product_ids):
                 yield parsed
         except TypeError:
             _LOGGER.error("Parse error in %s", cfg)

+ 10 - 0
tests/test_device_config.py

@@ -768,3 +768,13 @@ class TestDeviceConfig(IsolatedAsyncioTestCase):
         mock_config = {"id": "1", "name": "test", "type": "string"}
         cfg = TuyaDpsConfig(mock_entity, mock_config)
         self.assertIsNone(cfg.default)
+
+    def test_matching_with_product_id(self):
+        """Test that matching with product id works"""
+        cfg = get_config("smartplugv1")
+        self.assertTrue(cfg.matches({}, ["37mnhia3pojleqfh"]))
+
+    def test_matched_product_id_with_conflict_rejected(self):
+        """Test that matching with product id fails when there is a conflict"""
+        cfg = get_config("smartplugv1")
+        self.assertFalse(cfg.matches({"1": "wrong_type"}, ["37mnhia3pojleqfh"]))