Explorar el Código

Fix protocol version dropdown type mismatch (#4646)

* fix protocol selection dropdown

* add test to catch protocol version type mismatch in dropdown
kongo09 hace 2 días
padre
commit
47b8fa554d
Se han modificado 2 ficheros con 42 adiciones y 6 borrados
  1. 11 5
      custom_components/tuya_local/config_flow.py
  2. 31 1
      tests/test_config_flow.py

+ 11 - 5
custom_components/tuya_local/config_flow.py

@@ -343,11 +343,14 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
             host_opts = {"default": self.__cloud_device.get("ip")}
             host_opts = {"default": self.__cloud_device.get("ip")}
             key_opts = {"default": self.__cloud_device.get(CONF_LOCAL_KEY)}
             key_opts = {"default": self.__cloud_device.get(CONF_LOCAL_KEY)}
             if self.__cloud_device.get("version"):
             if self.__cloud_device.get("version"):
-                proto_opts = {"default": float(self.__cloud_device.get("version"))}
+                proto_opts = {"default": str(self.__cloud_device.get("version"))}
             if self.__cloud_device.get(CONF_DEVICE_CID):
             if self.__cloud_device.get(CONF_DEVICE_CID):
                 devcid_opts = {"default": self.__cloud_device.get(CONF_DEVICE_CID)}
                 devcid_opts = {"default": self.__cloud_device.get(CONF_DEVICE_CID)}
 
 
         if user_input is not None:
         if user_input is not None:
+            proto = user_input.get(CONF_PROTOCOL_VERSION)
+            if proto != "auto":
+                user_input[CONF_PROTOCOL_VERSION] = float(proto)
             self.device = await async_test_connection(user_input, self.hass)
             self.device = await async_test_connection(user_input, self.hass)
             if self.device:
             if self.device:
                 self.data = user_input
                 self.data = user_input
@@ -384,7 +387,7 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
                 key_opts["default"] = user_input[CONF_LOCAL_KEY]
                 key_opts["default"] = user_input[CONF_LOCAL_KEY]
                 if CONF_DEVICE_CID in user_input:
                 if CONF_DEVICE_CID in user_input:
                     devcid_opts["default"] = user_input[CONF_DEVICE_CID]
                     devcid_opts["default"] = user_input[CONF_DEVICE_CID]
-                proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
+                proto_opts["default"] = str(user_input[CONF_PROTOCOL_VERSION])
                 polling_opts["default"] = user_input[CONF_POLL_ONLY]
                 polling_opts["default"] = user_input[CONF_POLL_ONLY]
 
 
         return self.async_show_form(
         return self.async_show_form(
@@ -397,7 +400,7 @@ class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
                     vol.Required(
                     vol.Required(
                         CONF_PROTOCOL_VERSION,
                         CONF_PROTOCOL_VERSION,
                         **proto_opts,
                         **proto_opts,
-                    ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
+                    ): vol.In(["auto"] + [str(v) for v in API_PROTOCOL_VERSIONS]),
                     vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
                     vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
                     vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
                     vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
                 }
                 }
@@ -554,6 +557,9 @@ class OptionsFlowHandler(OptionsFlow):
         config = {**self.config_entry.data, **self.config_entry.options}
         config = {**self.config_entry.data, **self.config_entry.options}
 
 
         if user_input is not None:
         if user_input is not None:
+            proto = user_input.get(CONF_PROTOCOL_VERSION)
+            if proto != "auto":
+                user_input[CONF_PROTOCOL_VERSION] = float(proto)
             config = {**config, **user_input}
             config = {**config, **user_input}
             device = await async_test_connection(config, self.hass)
             device = await async_test_connection(config, self.hass)
             if device:
             if device:
@@ -569,8 +575,8 @@ class OptionsFlowHandler(OptionsFlow):
             vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
             vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
             vol.Required(
             vol.Required(
                 CONF_PROTOCOL_VERSION,
                 CONF_PROTOCOL_VERSION,
-                default=config.get(CONF_PROTOCOL_VERSION, "auto"),
-            ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
+                default=str(config.get(CONF_PROTOCOL_VERSION, "auto")),
+            ): vol.In(["auto"] + [str(v) for v in API_PROTOCOL_VERSIONS]),
             vol.Required(
             vol.Required(
                 CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
                 CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
             ): bool,
             ): bool,

+ 31 - 1
tests/test_config_flow.py

@@ -261,6 +261,36 @@ async def test_flow_user_init(hass, mocker):
         pass
         pass
 
 
 
 
+@pytest.mark.asyncio
+async def test_flow_user_init_protocol_options_are_strings(hass, mocker):
+    """Test that protocol version dropdown uses strings, not floats."""
+    result = await hass.config_entries.flow.async_init(
+        DOMAIN, context={"source": "local"}
+    )
+    schema = result["data_schema"]
+    # Validate that string protocol versions are accepted
+    schema(
+        {
+            CONF_DEVICE_ID: "test",
+            CONF_LOCAL_KEY: TESTKEY,
+            CONF_HOST: "test",
+            CONF_PROTOCOL_VERSION: "3.3",
+            CONF_POLL_ONLY: False,
+        }
+    )
+    # Validate that float protocol versions are rejected
+    with pytest.raises(vol.MultipleInvalid):
+        schema(
+            {
+                CONF_DEVICE_ID: "test",
+                CONF_LOCAL_KEY: TESTKEY,
+                CONF_HOST: "test",
+                CONF_PROTOCOL_VERSION: 3.3,
+                CONF_POLL_ONLY: False,
+            }
+        )
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_async_test_connection_valid(hass, mocker):
 async def test_async_test_connection_valid(hass, mocker):
     """Test that device is returned when connection is valid."""
     """Test that device is returned when connection is valid."""
@@ -622,7 +652,7 @@ async def test_options_flow_modifies_config(hass, bypass_setup, mocker):
             CONF_HOST: "new_hostname",
             CONF_HOST: "new_hostname",
             CONF_LOCAL_KEY: "new_key",
             CONF_LOCAL_KEY: "new_key",
             CONF_POLL_ONLY: False,
             CONF_POLL_ONLY: False,
-            CONF_PROTOCOL_VERSION: 3.3,
+            CONF_PROTOCOL_VERSION: "3.3",
         },
         },
     )
     )
     expected = {
     expected = {