Просмотр исходного кода

Show manufacturer and model in config flow and HA device registry (#4609)

* Show manufacturer and model in config flow and device details

* fix linting error

* improve displayed device choice

* fix device info so that tests pass

* fix tests
kongo09 4 дней назад
Родитель
Сommit
0495c6e406

+ 40 - 14
custom_components/tuya_local/config_flow.py

@@ -31,6 +31,8 @@ from .const import (
     CONF_DEVICE_CID,
     CONF_DEVICE_CID,
     CONF_DEVICE_ID,
     CONF_DEVICE_ID,
     CONF_LOCAL_KEY,
     CONF_LOCAL_KEY,
+    CONF_MANUFACTURER,
+    CONF_MODEL,
     CONF_POLL_ONLY,
     CONF_POLL_ONLY,
     CONF_PROTOCOL_VERSION,
     CONF_PROTOCOL_VERSION,
     CONF_TYPE,
     CONF_TYPE,
@@ -406,22 +408,47 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
 
 
     async def async_step_select_type(self, user_input=None):
     async def async_step_select_type(self, user_input=None):
         if user_input is not None:
         if user_input is not None:
-            self.data[CONF_TYPE] = user_input[CONF_TYPE]
+            # Value is a compound key: "config_type||manufacturer||model"
+            parts = user_input[CONF_TYPE].split("||", 2)
+            self.data[CONF_TYPE] = parts[0]
+            if len(parts) > 1 and parts[1]:
+                self.data[CONF_MANUFACTURER] = parts[1]
+            if len(parts) > 2 and parts[2]:
+                self.data[CONF_MODEL] = parts[2]
             return await self.async_step_choose_entities()
             return await self.async_step_choose_entities()
 
 
-        types = []
+        all_matches = []
         best_match = 0
         best_match = 0
         best_matching_type = None
         best_matching_type = None
+        best_matching_key = None
+        has_product_id_match = False
 
 
         for type in await self.device.async_possible_types():
         for type in await self.device.async_possible_types():
-            types.append(type.config_type)
             q = type.match_quality(
             q = type.match_quality(
                 self.device._get_cached_state(),
                 self.device._get_cached_state(),
                 self.device._product_ids,
                 self.device._product_ids,
             )
             )
-            if q > best_match:
-                best_match = q
-                best_matching_type = type.config_type
+            if q > 100:
+                has_product_id_match = True
+            for manufacturer, model in type.product_display_entries(
+                self.device._product_ids
+            ):
+                key = f"{type.config_type}||{manufacturer or ''}||{model or ''}"
+                parts = [p for p in [manufacturer, model] if p]
+                if parts:
+                    label = f"{' '.join(parts)} ({type.config_type})"
+                else:
+                    label = f"{type.name} ({type.config_type})"
+                all_matches.append((SelectOptionDict(value=key, label=label), q))
+                if q > best_match:
+                    best_match = q
+                    best_matching_type = type.config_type
+                    best_matching_key = key
+
+        if has_product_id_match:
+            type_options = [opt for opt, q in all_matches if q > 100]
+        else:
+            type_options = [opt for opt, _ in all_matches]
 
 
         best_match = int(best_match)
         best_match = int(best_match)
         dps = self.device._get_cached_state()
         dps = self.device._get_cached_state()
@@ -463,14 +490,14 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
         _LOGGER.warning(
         _LOGGER.warning(
             "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
             "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
         )
         )
-        if types:
+        if type_options:
             detected = getattr(self, "_auto_detected_protocol", None)
             detected = getattr(self, "_auto_detected_protocol", None)
             schema = vol.Schema(
             schema = vol.Schema(
                 {
                 {
                     vol.Required(
                     vol.Required(
                         CONF_TYPE,
                         CONF_TYPE,
-                        default=best_matching_type,
-                    ): vol.In(types),
+                        default=best_matching_key,
+                    ): SelectSelector(SelectSelectorConfig(options=type_options)),
                 }
                 }
             )
             )
             if detected:
             if detected:
@@ -490,17 +517,16 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
         return await self.async_step_select_type(user_input)
         return await self.async_step_select_type(user_input)
 
 
     async def async_step_choose_entities(self, user_input=None):
     async def async_step_choose_entities(self, user_input=None):
+        config = await self.hass.async_add_executor_job(
+            get_config,
+            self.data[CONF_TYPE],
+        )
         if user_input is not None:
         if user_input is not None:
             title = user_input[CONF_NAME]
             title = user_input[CONF_NAME]
             del user_input[CONF_NAME]
             del user_input[CONF_NAME]
-
             return self.async_create_entry(
             return self.async_create_entry(
                 title=title, data={**self.data, **user_input}
                 title=title, data={**self.data, **user_input}
             )
             )
-        config = await self.hass.async_add_executor_job(
-            get_config,
-            self.data[CONF_TYPE],
-        )
         schema = {vol.Required(CONF_NAME, default=config.name): str}
         schema = {vol.Required(CONF_NAME, default=config.name): str}
 
 
         return self.async_show_form(
         return self.async_show_form(

+ 2 - 0
custom_components/tuya_local/const.py

@@ -4,6 +4,8 @@ DATA_STORE = "store"
 CONF_DEVICE_ID = "device_id"
 CONF_DEVICE_ID = "device_id"
 CONF_LOCAL_KEY = "local_key"
 CONF_LOCAL_KEY = "local_key"
 CONF_TYPE = "type"
 CONF_TYPE = "type"
+CONF_MANUFACTURER = "manufacturer"
+CONF_MODEL = "model"
 CONF_POLL_ONLY = "poll_only"
 CONF_POLL_ONLY = "poll_only"
 CONF_DEVICE_CID = "device_cid"
 CONF_DEVICE_CID = "device_cid"
 CONF_PROTOCOL_VERSION = "protocol_version"
 CONF_PROTOCOL_VERSION = "protocol_version"

+ 16 - 3
custom_components/tuya_local/device.py

@@ -22,6 +22,8 @@ from .const import (
     CONF_DEVICE_CID,
     CONF_DEVICE_CID,
     CONF_DEVICE_ID,
     CONF_DEVICE_ID,
     CONF_LOCAL_KEY,
     CONF_LOCAL_KEY,
+    CONF_MANUFACTURER,
+    CONF_MODEL,
     CONF_POLL_ONLY,
     CONF_POLL_ONLY,
     CONF_PROTOCOL_VERSION,
     CONF_PROTOCOL_VERSION,
     DOMAIN,
     DOMAIN,
@@ -49,6 +51,8 @@ class TuyaLocalDevice(object):
         dev_cid,
         dev_cid,
         hass: HomeAssistant,
         hass: HomeAssistant,
         poll_only=False,
         poll_only=False,
+        manufacturer=None,
+        model=None,
     ):
     ):
         """
         """
         Represents a Tuya-based device.
         Represents a Tuya-based device.
@@ -61,9 +65,13 @@ class TuyaLocalDevice(object):
             protocol_version (str | number): The protocol version.
             protocol_version (str | number): The protocol version.
             dev_cid (str): The sub device id.
             dev_cid (str): The sub device id.
             hass (HomeAssistant): The Home Assistant instance.
             hass (HomeAssistant): The Home Assistant instance.
-            poll_only (bool): True if the device should be polled only
+            poll_only (bool): True if the device should be polled only.
+            manufacturer (str | None): The device manufacturer, if known.
+            model (str | None): The device model, if known.
         """
         """
         self._name = name
         self._name = name
+        self._manufacturer = manufacturer
+        self._model = model
         self._children = []
         self._children = []
         self._force_dps = []
         self._force_dps = []
         self._product_ids = []
         self._product_ids = []
@@ -161,11 +169,14 @@ class TuyaLocalDevice(object):
     @property
     @property
     def device_info(self):
     def device_info(self):
         """Return the device information for this device."""
         """Return the device information for this device."""
-        return {
+        info = {
             "identifiers": {(DOMAIN, self.unique_id)},
             "identifiers": {(DOMAIN, self.unique_id)},
             "name": self.name,
             "name": self.name,
-            "manufacturer": "Tuya",
+            "manufacturer": self._manufacturer or "Tuya",
         }
         }
+        if self._model:
+            info["model"] = self._model
+        return info
 
 
     @property
     @property
     def has_returned_state(self):
     def has_returned_state(self):
@@ -763,6 +774,8 @@ def setup_device(hass: HomeAssistant, config: dict):
         config.get(CONF_DEVICE_CID),
         config.get(CONF_DEVICE_CID),
         hass,
         hass,
         config[CONF_POLL_ONLY],
         config[CONF_POLL_ONLY],
+        manufacturer=config.get(CONF_MANUFACTURER),
+        model=config.get(CONF_MODEL),
     )
     )
     hass.data[DOMAIN][get_device_id(config)] = {
     hass.data[DOMAIN][get_device_id(config)] = {
         "device": device,
         "device": device,

+ 27 - 0
custom_components/tuya_local/helpers/device_config.py

@@ -229,6 +229,33 @@ class TuyaDeviceConfig:
 
 
         return product_match or round((total - len(keys)) * 100 / total)
         return product_match or round((total - len(keys)) * 100 / total)
 
 
+    def product_display_entries(self, product_ids=None):
+        """Return distinct (manufacturer, model) pairs for display in the config flow.
+
+        When product_ids is provided, only products whose id matches are
+        included.  When there are no confirmed matches (or no product_ids),
+        returns [(None, None)] so the caller falls back to the config filename.
+        """
+        seen = set()
+        result = []
+
+        for p in self._config.get("products", []):
+            if product_ids and p.get("id") not in product_ids:
+                continue
+            manufacturer = p.get("manufacturer")
+            model = p.get("model")
+            if manufacturer is None and model is None:
+                continue
+            key = (manufacturer, model)
+            if key not in seen:
+                seen.add(key)
+                result.append(key)
+
+        if not result:
+            result.append((None, None))
+
+        return result
+
 
 
 class TuyaEntityConfig:
 class TuyaEntityConfig:
     """Representation of an entity config for a supported entity."""
     """Representation of an entity config for a supported entity."""

+ 4 - 3
tests/test_config_flow.py

@@ -366,6 +366,7 @@ def setup_device_mock(mock, mocker, failure=False, type="test"):
     mock_type.legacy_type = type
     mock_type.legacy_type = type
     mock_type.config_type = type
     mock_type.config_type = type
     mock_type.match_quality.return_value = 100
     mock_type.match_quality.return_value = 100
+    mock_type.product_display_entries.return_value = [(None, None)]
     mock.async_possible_types = mocker.AsyncMock(
     mock.async_possible_types = mocker.AsyncMock(
         return_value=[mock_type] if not failure else []
         return_value=[mock_type] if not failure else []
     )
     )
@@ -422,11 +423,11 @@ async def test_flow_select_type_init(hass, mocker):
     # Check the schema.  Simple comparison does not work since they are not
     # Check the schema.  Simple comparison does not work since they are not
     # the same object
     # the same object
     try:
     try:
-        result["data_schema"]({CONF_TYPE: "test"})
+        result["data_schema"]({CONF_TYPE: "test||||"})
     except vol.MultipleInvalid:
     except vol.MultipleInvalid:
         assert False
         assert False
     try:
     try:
-        result["data_schema"]({CONF_TYPE: "not_test"})
+        result["data_schema"]({CONF_TYPE: "not_test||||"})
         assert False
         assert False
     except vol.MultipleInvalid:
     except vol.MultipleInvalid:
         pass
         pass
@@ -458,7 +459,7 @@ async def test_flow_select_type_data_valid(hass, mocker):
     )
     )
     result = await hass.config_entries.flow.async_configure(
     result = await hass.config_entries.flow.async_configure(
         flow["flow_id"],
         flow["flow_id"],
-        user_input={CONF_TYPE: "smartplugv1"},
+        user_input={CONF_TYPE: "smartplugv1||||"},
     )
     )
     assert "form" == result["type"]
     assert "form" == result["type"]
     assert "choose_entities" == result["step_id"]
     assert "choose_entities" == result["step_id"]