Преглед изворни кода

Cloud refactoring: hass is not available in constructor

Need to initialise the Cloud on demand, as the hass object is not
yet available in the config flow constructor.

Cache and reuse authentication between instances.
Jason Rumney пре 1 година
родитељ
комит
38744c037a
2 измењених фајлова са 13 додато и 2 уклоњено
  1. 4 0
      custom_components/tuya_local/cloud.py
  2. 9 2
      custom_components/tuya_local/config_flow.py

+ 4 - 0
custom_components/tuya_local/cloud.py

@@ -55,6 +55,9 @@ class Cloud:
         self.__hass = hass
         self.__hass = hass
         self.__error_code = None
         self.__error_code = None
         self.__error_msg = None
         self.__error_msg = None
+        # Restore cached authentication
+        if cached := self.__hass.data[DOMAIN].get("auth_cache"):
+            self.__authentication = cached
 
 
     async def async_get_qr_code(self, user_code: str | None = None) -> bool:
     async def async_get_qr_code(self, user_code: str | None = None) -> bool:
         """Get QR code from Tuya server for user code authentication."""
         """Get QR code from Tuya server for user code authentication."""
@@ -105,6 +108,7 @@ class Cloud:
                     "refresh_token": info["refresh_token"],
                     "refresh_token": info["refresh_token"],
                 },
                 },
             }
             }
+            self.__hass.data[DOMAIN]["auth_cache"] = self.__authentication
         else:
         else:
             self.__error_code = info.get(TUYA_RESPONSE_CODE, {})
             self.__error_code = info.get(TUYA_RESPONSE_CODE, {})
             self.__error_msg = info.get(TUYA_RESPONSE_MSG, "Unknown error")
             self.__error_msg = info.get(TUYA_RESPONSE_MSG, "Unknown error")

+ 9 - 2
custom_components/tuya_local/config_flow.py

@@ -54,7 +54,11 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
 
 
     def __init__(self) -> None:
     def __init__(self) -> None:
         """Initialize the config flow."""
         """Initialize the config flow."""
-        self.cloud = Cloud(self.hass)
+        self.cloud = None
+
+    def init_cloud(self):
+        if self.cloud is None:
+            self.cloud = Cloud(self.hass)
 
 
     async def async_step_user(self, user_input=None):
     async def async_step_user(self, user_input=None):
         errors = {}
         errors = {}
@@ -66,6 +70,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
 
 
         if user_input is not None:
         if user_input is not None:
             if user_input["setup_mode"] == "cloud":
             if user_input["setup_mode"] == "cloud":
+                self.init_cloud()
                 try:
                 try:
                     if self.cloud.is_authenticated:
                     if self.cloud.is_authenticated:
                         self.__cloud_devices = await self.cloud.async_get_devices()
                         self.__cloud_devices = await self.cloud.async_get_devices()
@@ -101,6 +106,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
         """Step user."""
         """Step user."""
         errors = {}
         errors = {}
         placeholders = {}
         placeholders = {}
+        self.init_cloud()
 
 
         if user_input is not None:
         if user_input is not None:
             response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
             response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
@@ -146,7 +152,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                     }
                     }
                 ),
                 ),
             )
             )
-
+        self.init_cloud()
         if not await self.cloud.async_login():
         if not await self.cloud.async_login():
             # Try to get a new QR code on failure
             # Try to get a new QR code on failure
             response = await self.cloud.async_get_qr_code()
             response = await self.cloud.async_get_qr_code()
@@ -370,6 +376,7 @@ class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
                 self.__cloud_device["product_name"],
                 self.__cloud_device["product_name"],
                 self.__cloud_device["product_id"],
                 self.__cloud_device["product_id"],
             )
             )
+            self.init_cloud()
             response = self.cloud.async_get_datamodel(self.__cloud_device["device_id"])
             response = self.cloud.async_get_datamodel(self.__cloud_device["device_id"])
             if response and response["result"] and response["result"]["model"]:
             if response and response["result"] and response["result"]["model"]:
                 model = json.loads(response["result"]["model"])
                 model = json.loads(response["result"]["model"])