config_flow.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. import asyncio
  2. import logging
  3. from collections import OrderedDict
  4. from typing import Any
  5. import tinytuya
  6. import voluptuous as vol
  7. from homeassistant.config_entries import (
  8. CONN_CLASS_LOCAL_PUSH,
  9. ConfigEntry,
  10. ConfigFlow,
  11. OptionsFlow,
  12. )
  13. from homeassistant.const import CONF_HOST, CONF_NAME
  14. from homeassistant.core import HomeAssistant, callback
  15. from homeassistant.data_entry_flow import FlowResult
  16. from homeassistant.helpers.selector import (
  17. QrCodeSelector,
  18. QrCodeSelectorConfig,
  19. QrErrorCorrectionLevel,
  20. SelectOptionDict,
  21. SelectSelector,
  22. SelectSelectorConfig,
  23. SelectSelectorMode,
  24. )
  25. from . import DOMAIN
  26. from .cloud import Cloud
  27. from .const import (
  28. API_PROTOCOL_VERSIONS,
  29. CONF_DEVICE_CID,
  30. CONF_DEVICE_ID,
  31. CONF_LOCAL_KEY,
  32. CONF_POLL_ONLY,
  33. CONF_PROTOCOL_VERSION,
  34. CONF_TYPE,
  35. CONF_USER_CODE,
  36. DATA_STORE,
  37. )
  38. from .device import TuyaLocalDevice
  39. from .helpers.config import get_device_id
  40. from .helpers.device_config import get_config
  41. from .helpers.log import log_json
  42. _LOGGER = logging.getLogger(__name__)
  43. DEVICE_DETAILS_URL = (
  44. "https://github.com/make-all/tuya-local/blob/main/DEVICE_DETAILS.md"
  45. "#finding-your-device-id-and-local-key"
  46. )
  47. class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
  48. VERSION = 13
  49. MINOR_VERSION = 16
  50. CONNECTION_CLASS = CONN_CLASS_LOCAL_PUSH
  51. device = None
  52. data = {}
  53. __qr_code: str | None = None
  54. __cloud_devices: dict[str, Any] = {}
  55. __cloud_device: dict[str, Any] | None = None
  56. def __init__(self) -> None:
  57. """Initialize the config flow."""
  58. self.cloud = None
  59. def init_cloud(self):
  60. if self.cloud is None:
  61. self.cloud = Cloud(self.hass)
  62. async def async_step_user(self, user_input=None):
  63. errors = {}
  64. if self.hass.data.get(DOMAIN) is None:
  65. self.hass.data[DOMAIN] = {}
  66. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  67. self.hass.data[DOMAIN][DATA_STORE] = {}
  68. if user_input is not None:
  69. mode = user_input.get("setup_mode")
  70. if mode == "cloud" or mode == "cloud_fresh_login":
  71. self.init_cloud()
  72. try:
  73. if mode == "cloud_fresh_login":
  74. # Force a fresh login
  75. self.cloud.logout()
  76. if self.cloud.is_authenticated:
  77. self.__cloud_devices = await self.cloud.async_get_devices()
  78. return await self.async_step_choose_device()
  79. except Exception as e:
  80. # Re-authentication is needed.
  81. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  82. _LOGGER.warning("Re-authentication is required.")
  83. return await self.async_step_cloud()
  84. if mode == "manual":
  85. return await self.async_step_local()
  86. # Build form
  87. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  88. fields[vol.Required("setup_mode")] = SelectSelector(
  89. SelectSelectorConfig(
  90. options=["cloud", "manual", "cloud_fresh_login"],
  91. mode=SelectSelectorMode.LIST,
  92. translation_key="setup_mode",
  93. )
  94. )
  95. return self.async_show_form(
  96. step_id="user",
  97. data_schema=vol.Schema(fields),
  98. errors=errors or {},
  99. last_step=False,
  100. )
  101. async def async_step_cloud(
  102. self, user_input: dict[str, Any] | None = None
  103. ) -> FlowResult:
  104. """Step user."""
  105. errors = {}
  106. placeholders = {}
  107. self.init_cloud()
  108. if user_input is not None:
  109. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  110. if response:
  111. self.__qr_code = response
  112. return await self.async_step_scan()
  113. errors["base"] = "login_error"
  114. placeholders = self.cloud.last_error
  115. else:
  116. user_input = {}
  117. return self.async_show_form(
  118. step_id="cloud",
  119. data_schema=vol.Schema(
  120. {
  121. vol.Required(
  122. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  123. ): str,
  124. }
  125. ),
  126. errors=errors,
  127. description_placeholders=placeholders,
  128. )
  129. async def async_step_scan(
  130. self, user_input: dict[str, Any] | None = None
  131. ) -> FlowResult:
  132. """Step scan."""
  133. if user_input is None:
  134. return self.async_show_form(
  135. step_id="scan",
  136. data_schema=vol.Schema(
  137. {
  138. vol.Optional("QR"): QrCodeSelector(
  139. config=QrCodeSelectorConfig(
  140. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  141. scale=5,
  142. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  143. )
  144. )
  145. }
  146. ),
  147. )
  148. self.init_cloud()
  149. if not await self.cloud.async_login():
  150. # Try to get a new QR code on failure
  151. response = await self.cloud.async_get_qr_code()
  152. errors = {"base": "login_error"}
  153. placeholders = self.cloud.last_error
  154. if response:
  155. self.__qr_code = response
  156. return self.async_show_form(
  157. step_id="scan",
  158. errors=errors,
  159. data_schema=vol.Schema(
  160. {
  161. vol.Optional("QR"): QrCodeSelector(
  162. config=QrCodeSelectorConfig(
  163. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  164. scale=5,
  165. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  166. )
  167. )
  168. }
  169. ),
  170. description_placeholders=placeholders,
  171. )
  172. self.__cloud_devices = await self.cloud.async_get_devices()
  173. return await self.async_step_choose_device()
  174. async def async_step_choose_device(self, user_input=None):
  175. errors = {}
  176. if user_input is not None:
  177. device_choice = self.__cloud_devices[user_input["device_id"]]
  178. if device_choice["ip"] != "":
  179. # This is a directly addable device.
  180. if user_input["hub_id"] == "None":
  181. device_choice["ip"] = ""
  182. self.__cloud_device = device_choice
  183. return await self.async_step_search()
  184. else:
  185. # Show error if user selected a hub.
  186. errors["base"] = "does_not_need_hub"
  187. # Fall through to reshow the form.
  188. else:
  189. # This is an indirectly addressable device. Need to know which hub it is connected to.
  190. if user_input["hub_id"] != "None":
  191. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  192. # Populate node_id or uuid and local_key from the child
  193. # device to pass on complete information to the local step.
  194. hub_choice["ip"] = ""
  195. hub_choice[CONF_DEVICE_CID] = (
  196. device_choice["node_id"] or device_choice["uuid"]
  197. )
  198. if device_choice.get(CONF_LOCAL_KEY):
  199. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  200. # Communicate the sub device product id to help match the
  201. # correect device config in the next step.
  202. hub_choice["product_id"] = device_choice["product_id"]
  203. self.__cloud_device = hub_choice
  204. return await self.async_step_search()
  205. else:
  206. # Show error if user did not select a hub.
  207. errors["base"] = "needs_hub"
  208. # Fall through to reshow the form.
  209. device_list = []
  210. for key in self.__cloud_devices.keys():
  211. device_entry = self.__cloud_devices[key]
  212. if device_entry.get("exists"):
  213. continue
  214. if device_entry[CONF_LOCAL_KEY] != "":
  215. if device_entry["online"]:
  216. device_list.append(
  217. SelectOptionDict(
  218. value=key,
  219. label=f"{device_entry['name']} ({device_entry['product_name']})",
  220. )
  221. )
  222. else:
  223. device_list.append(
  224. SelectOptionDict(
  225. value=key,
  226. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  227. )
  228. )
  229. _LOGGER.debug(f"Device count: {len(device_list)}")
  230. if len(device_list) == 0:
  231. return self.async_abort(reason="no_devices")
  232. device_selector = SelectSelector(
  233. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  234. )
  235. hub_list = []
  236. hub_list.append(SelectOptionDict(value="None", label="None"))
  237. for key in self.__cloud_devices.keys():
  238. hub_entry = self.__cloud_devices[key]
  239. if hub_entry["is_hub"]:
  240. hub_list.append(
  241. SelectOptionDict(
  242. value=key,
  243. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  244. )
  245. )
  246. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  247. hub_selector = SelectSelector(
  248. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  249. )
  250. # Build form
  251. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  252. fields[vol.Required("device_id")] = device_selector
  253. fields[vol.Required("hub_id")] = hub_selector
  254. return self.async_show_form(
  255. step_id="choose_device",
  256. data_schema=vol.Schema(fields),
  257. errors=errors or {},
  258. last_step=False,
  259. )
  260. async def async_step_search(self, user_input=None):
  261. if user_input is not None:
  262. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  263. # This scan will take 18s with the default settings. If we cannot find the device we
  264. # will just leave the IP address blank and hope the user can discover the IP by other
  265. # means such as router device IP assignments.
  266. _LOGGER.debug(
  267. f"Scanning network to get IP address for {self.__cloud_device.get('id', 'DEVICE_KEY_UNAVAILABLE')}."
  268. )
  269. self.__cloud_device["ip"] = ""
  270. try:
  271. local_device = await self.hass.async_add_executor_job(
  272. scan_for_device, self.__cloud_device.get("id")
  273. )
  274. except OSError:
  275. local_device = {"ip": None, "version": ""}
  276. if local_device.get("ip"):
  277. _LOGGER.debug(f"Found: {local_device}")
  278. self.__cloud_device["ip"] = local_device.get("ip")
  279. self.__cloud_device["version"] = local_device.get("version")
  280. if not self.__cloud_device.get(CONF_DEVICE_CID):
  281. self.__cloud_device["local_product_id"] = local_device.get(
  282. "productKey"
  283. )
  284. else:
  285. _LOGGER.warning(
  286. f"Could not find device: {self.__cloud_device.get('id', 'DEVICE_KEY_UNAVAILABLE')}"
  287. )
  288. return await self.async_step_local()
  289. return self.async_show_form(
  290. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  291. )
  292. async def async_step_local(self, user_input=None):
  293. errors = {}
  294. devid_opts = {}
  295. host_opts = {"default": ""}
  296. key_opts = {}
  297. proto_opts = {"default": "auto"}
  298. polling_opts = {"default": False}
  299. devcid_opts = {}
  300. if self.__cloud_device is not None:
  301. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  302. devid_opts = {"default": self.__cloud_device.get("id")}
  303. host_opts = {"default": self.__cloud_device.get("ip")}
  304. key_opts = {"default": self.__cloud_device.get(CONF_LOCAL_KEY)}
  305. if self.__cloud_device.get("version"):
  306. proto_opts = {"default": float(self.__cloud_device.get("version"))}
  307. if self.__cloud_device.get(CONF_DEVICE_CID):
  308. devcid_opts = {"default": self.__cloud_device.get(CONF_DEVICE_CID)}
  309. if user_input is not None:
  310. self.device = await async_test_connection(user_input, self.hass)
  311. if self.device:
  312. self.data = user_input
  313. # If auto mode found a working protocol, save it so future
  314. # HA restarts connect directly without re-cycling all versions.
  315. self._auto_detected_protocol = None
  316. if (
  317. user_input.get(CONF_PROTOCOL_VERSION) == "auto"
  318. and self.device._protocol_configured != "auto"
  319. ):
  320. self._auto_detected_protocol = self.device._protocol_configured
  321. self.data = {
  322. **self.data,
  323. CONF_PROTOCOL_VERSION: self._auto_detected_protocol,
  324. }
  325. if self.__cloud_device:
  326. if self.__cloud_device.get("product_id"):
  327. self.device.set_detected_product_id(
  328. self.__cloud_device.get("product_id")
  329. )
  330. if self.__cloud_device.get("local_product_id"):
  331. self.device.set_detected_product_id(
  332. self.__cloud_device.get("local_product_id")
  333. )
  334. await self.async_set_unique_id(
  335. user_input.get(CONF_DEVICE_CID, user_input[CONF_DEVICE_ID])
  336. )
  337. self._abort_if_unique_id_configured()
  338. return await self.async_step_select_type()
  339. else:
  340. errors["base"] = "connection"
  341. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  342. host_opts["default"] = user_input[CONF_HOST]
  343. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  344. if CONF_DEVICE_CID in user_input:
  345. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  346. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  347. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  348. return self.async_show_form(
  349. step_id="local",
  350. data_schema=vol.Schema(
  351. {
  352. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  353. vol.Required(CONF_HOST, **host_opts): str,
  354. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  355. vol.Required(
  356. CONF_PROTOCOL_VERSION,
  357. **proto_opts,
  358. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  359. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  360. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  361. }
  362. ),
  363. description_placeholders={"device_details_url": DEVICE_DETAILS_URL},
  364. errors=errors,
  365. )
  366. async def async_step_select_type(self, user_input=None):
  367. if user_input is not None:
  368. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  369. return await self.async_step_choose_entities()
  370. types = []
  371. best_match = 0
  372. best_matching_type = None
  373. for type in await self.device.async_possible_types():
  374. types.append(type.config_type)
  375. q = type.match_quality(
  376. self.device._get_cached_state(),
  377. self.device._product_ids,
  378. )
  379. if q > best_match:
  380. best_match = q
  381. best_matching_type = type.config_type
  382. best_match = int(best_match)
  383. dps = self.device._get_cached_state()
  384. if self.__cloud_device:
  385. _LOGGER.warning(
  386. "Adding %s device with product id %s",
  387. self.__cloud_device.get("product_name", "UNKNOWN"),
  388. self.__cloud_device.get("product_id", "UNKNOWN"),
  389. )
  390. if self.__cloud_device.get("local_product_id") and self.__cloud_device.get(
  391. "local_product_id"
  392. ) != self.__cloud_device.get("product_id"):
  393. _LOGGER.warning(
  394. "Local product id differs from cloud: %s",
  395. self.__cloud_device.get("local_product_id"),
  396. )
  397. try:
  398. self.init_cloud()
  399. model = await self.cloud.async_get_datamodel(
  400. self.__cloud_device.get("id"),
  401. )
  402. if model:
  403. _LOGGER.warning(
  404. "Partial cloud device spec:\n%s",
  405. log_json(model),
  406. )
  407. except Exception as e:
  408. _LOGGER.warning(
  409. "Unable to fetch data model from cloud: %s %s",
  410. type(e).__name__,
  411. e,
  412. )
  413. _LOGGER.warning(
  414. "Device matches %s with quality of %d%%. LOCAL DPS: %s",
  415. best_matching_type,
  416. best_match,
  417. log_json(dps),
  418. )
  419. _LOGGER.warning(
  420. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  421. )
  422. if types:
  423. detected = getattr(self, "_auto_detected_protocol", None)
  424. schema = vol.Schema(
  425. {
  426. vol.Required(
  427. CONF_TYPE,
  428. default=best_matching_type,
  429. ): vol.In(types),
  430. }
  431. )
  432. if detected:
  433. return self.async_show_form(
  434. step_id="select_type_auto_detected",
  435. data_schema=schema,
  436. description_placeholders={"detected_protocol": str(detected)},
  437. )
  438. return self.async_show_form(
  439. step_id="select_type",
  440. data_schema=schema,
  441. )
  442. else:
  443. return self.async_abort(reason="not_supported")
  444. async def async_step_select_type_auto_detected(self, user_input=None):
  445. return await self.async_step_select_type(user_input)
  446. async def async_step_choose_entities(self, user_input=None):
  447. if user_input is not None:
  448. title = user_input[CONF_NAME]
  449. del user_input[CONF_NAME]
  450. return self.async_create_entry(
  451. title=title, data={**self.data, **user_input}
  452. )
  453. config = await self.hass.async_add_executor_job(
  454. get_config,
  455. self.data[CONF_TYPE],
  456. )
  457. schema = {vol.Required(CONF_NAME, default=config.name): str}
  458. return self.async_show_form(
  459. step_id="choose_entities",
  460. data_schema=vol.Schema(schema),
  461. )
  462. @staticmethod
  463. @callback
  464. def async_get_options_flow(config_entry: ConfigEntry):
  465. return OptionsFlowHandler()
  466. class OptionsFlowHandler(OptionsFlow):
  467. def __init__(self):
  468. """Initialize options flow."""
  469. pass
  470. async def async_step_init(self, user_input=None):
  471. return await self.async_step_user(user_input)
  472. async def async_step_user(self, user_input=None):
  473. """Manage the options."""
  474. errors = {}
  475. config = {**self.config_entry.data, **self.config_entry.options}
  476. if user_input is not None:
  477. config = {**config, **user_input}
  478. device = await async_test_connection(config, self.hass)
  479. if device:
  480. return self.async_create_entry(title="", data=user_input)
  481. else:
  482. errors["base"] = "connection"
  483. schema = {
  484. vol.Required(
  485. CONF_LOCAL_KEY,
  486. default=config.get(CONF_LOCAL_KEY, ""),
  487. ): str,
  488. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  489. vol.Required(
  490. CONF_PROTOCOL_VERSION,
  491. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  492. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  493. vol.Required(
  494. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  495. ): bool,
  496. }
  497. cfg = await self.hass.async_add_executor_job(
  498. get_config,
  499. config[CONF_TYPE],
  500. )
  501. if cfg is None:
  502. return self.async_abort(reason="not_supported")
  503. return self.async_show_form(
  504. step_id="user",
  505. data_schema=vol.Schema(schema),
  506. description_placeholders={"device_details_url": DEVICE_DETAILS_URL},
  507. errors=errors,
  508. )
  509. def create_test_device(hass: HomeAssistant, config: dict):
  510. """Set up a tuya device based on passed in config."""
  511. subdevice_id = config.get(CONF_DEVICE_CID)
  512. device = TuyaLocalDevice(
  513. "Test",
  514. config[CONF_DEVICE_ID],
  515. config[CONF_HOST],
  516. config[CONF_LOCAL_KEY],
  517. config[CONF_PROTOCOL_VERSION],
  518. subdevice_id,
  519. hass,
  520. True,
  521. )
  522. return device
  523. async def async_test_connection(config: dict, hass: HomeAssistant):
  524. domain_data = hass.data.get(DOMAIN)
  525. existing = domain_data.get(get_device_id(config)) if domain_data else None
  526. if existing and existing.get("device"):
  527. _LOGGER.info("Pausing existing device to test new connection parameters")
  528. existing["device"].pause()
  529. await asyncio.sleep(5)
  530. retval = None
  531. if config.get(CONF_PROTOCOL_VERSION) == "auto":
  532. # Test each protocol with a fresh device object. Reusing one device
  533. # object across protocol rotations causes 3.4/3.5 handshakes to fail:
  534. # the shared tinytuya object carries stale internal state from the
  535. # prior connection attempts.
  536. for proto in API_PROTOCOL_VERSIONS:
  537. proto_config = {**config, CONF_PROTOCOL_VERSION: proto}
  538. device = None
  539. try:
  540. device = await hass.async_add_executor_job(
  541. create_test_device, hass, proto_config
  542. )
  543. await device.async_refresh()
  544. if device.has_returned_state:
  545. retval = device
  546. break
  547. except Exception as e:
  548. _LOGGER.debug("Protocol %s test failed with %s %s", proto, type(e), e)
  549. if device is not None:
  550. device._api.set_socketPersistent(False)
  551. if device._api.parent:
  552. device._api.parent.set_socketPersistent(False)
  553. else:
  554. try:
  555. device = await hass.async_add_executor_job(
  556. create_test_device,
  557. hass,
  558. config,
  559. )
  560. await device.async_refresh()
  561. retval = device if device.has_returned_state else None
  562. except Exception as e:
  563. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  564. if existing and existing.get("device"):
  565. _LOGGER.info("Restarting device after test")
  566. existing["device"].resume()
  567. return retval
  568. def scan_for_device(id):
  569. return tinytuya.find_device(dev_id=id)