config_flow.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. import asyncio
  2. import logging
  3. from collections import OrderedDict
  4. from typing import Any
  5. import tinytuya
  6. import voluptuous as vol
  7. from homeassistant import config_entries
  8. from homeassistant.const import CONF_HOST, CONF_NAME
  9. from homeassistant.core import HomeAssistant, callback
  10. from homeassistant.data_entry_flow import FlowResult
  11. from homeassistant.helpers.selector import (
  12. QrCodeSelector,
  13. QrCodeSelectorConfig,
  14. QrErrorCorrectionLevel,
  15. SelectOptionDict,
  16. SelectSelector,
  17. SelectSelectorConfig,
  18. SelectSelectorMode,
  19. )
  20. from tuya_sharing import (
  21. CustomerDevice,
  22. LoginControl,
  23. Manager,
  24. SharingDeviceListener,
  25. SharingTokenListener,
  26. )
  27. from . import DOMAIN
  28. from .const import (
  29. API_PROTOCOL_VERSIONS,
  30. CONF_DEVICE_CID,
  31. CONF_DEVICE_ID,
  32. CONF_ENDPOINT,
  33. CONF_LOCAL_KEY,
  34. CONF_POLL_ONLY,
  35. CONF_PROTOCOL_VERSION,
  36. CONF_TERMINAL_ID,
  37. CONF_TYPE,
  38. CONF_USER_CODE,
  39. DATA_STORE,
  40. TUYA_CLIENT_ID,
  41. TUYA_RESPONSE_CODE,
  42. TUYA_RESPONSE_MSG,
  43. TUYA_RESPONSE_QR_CODE,
  44. TUYA_RESPONSE_RESULT,
  45. TUYA_RESPONSE_SUCCESS,
  46. TUYA_SCHEMA,
  47. )
  48. from .device import TuyaLocalDevice
  49. from .helpers.config import get_device_id
  50. from .helpers.device_config import get_config
  51. from .helpers.log import log_json
  # Module-level logger shared by the flow handlers and listeners in this file.
  _LOGGER = logging.getLogger(name=__name__)
  class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
      """Config flow that adds a Tuya device to this integration.

      Two ways in, chosen in the ``user`` step:

      * cloud: ``cloud`` (user code) -> ``scan`` (QR login) ->
        ``choose_device`` -> ``search`` (local network scan) -> ``local``
      * manual: straight to ``local``

      Both then continue ``local`` (connection test) -> ``select_type`` ->
      ``choose_entities``, which creates the config entry.
      """

      VERSION = 13
      MINOR_VERSION = 3
      CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
      # Device object from the last successful connection test in the local step.
      device = None
      # Entry data collected so far. Assigned per instance in __init__ so that
      # concurrent flows never share one mutable class-level dict.
      data: dict
      __user_code: str
      __qr_code: str
      __authentication: dict | None
      __cloud_devices: dict
      __cloud_device: dict | None

      def __init__(self) -> None:
          """Initialize the config flow."""
          self.data = {}
          self.__login_control = LoginControl()
          self.__cloud_devices = {}
          self.__cloud_device = None

      async def async_step_user(self, user_input=None):
          """Ask whether to set the device up via the Tuya cloud or manually.

          Cloud setup reuses authentication cached in
          ``hass.data[DOMAIN][DATA_STORE]["authentication"]`` when present. If
          loading the device list with it fails, the user is sent through the
          QR-code login again.
          """
          errors = {}
          if self.hass.data.get(DOMAIN) is None:
              self.hass.data[DOMAIN] = {}
          if self.hass.data[DOMAIN].get(DATA_STORE) is None:
              self.hass.data[DOMAIN][DATA_STORE] = {}
          self.__authentication = self.hass.data[DOMAIN][DATA_STORE].get(
              "authentication", None
          )

          if user_input is not None:
              if user_input["setup_mode"] == "cloud":
                  if self.__authentication is not None:
                      # Only the cloud query is guarded: errors raised later in
                      # the flow must not be mistaken for expired credentials.
                      try:
                          self.__cloud_devices = await self.load_device_info()
                      except Exception as e:
                          # Re-authentication is needed.
                          _LOGGER.warning(
                              "Connection test failed with %s %s", type(e), e
                          )
                          _LOGGER.warning("Re-authentication is required.")
                      else:
                          return await self.async_step_choose_device(None)
                  return await self.async_step_cloud(None)
              if user_input["setup_mode"] == "manual":
                  return await self.async_step_local(None)

          # Build form
          fields: OrderedDict[vol.Marker, Any] = OrderedDict()
          fields[vol.Required("setup_mode")] = SelectSelector(
              SelectSelectorConfig(
                  options=["cloud", "manual"],
                  mode=SelectSelectorMode.LIST,
                  translation_key="setup_mode",
              )
          )
          return self.async_show_form(
              step_id="user",
              data_schema=vol.Schema(fields),
              errors=errors,
              last_step=False,
          )

      async def async_step_cloud(
          self, user_input: dict[str, Any] | None = None
      ) -> FlowResult:
          """Ask for the Tuya user code and request a login QR code for it.

          On success the flow moves to the ``scan`` step. On failure the form
          is shown again with the message and code from the Tuya response.
          """
          errors = {}
          placeholders = {}
          if user_input is not None:
              success, response = await self.__async_get_qr_code(
                  user_input[CONF_USER_CODE]
              )
              if success:
                  return await self.async_step_scan(None)
              errors["base"] = "login_error"
              placeholders = {
                  TUYA_RESPONSE_MSG: response.get(TUYA_RESPONSE_MSG, "Unknown error"),
                  TUYA_RESPONSE_CODE: response.get(TUYA_RESPONSE_CODE, "0"),
              }
          else:
              user_input = {}
          return self.async_show_form(
              step_id="cloud",
              data_schema=vol.Schema(
                  {
                      vol.Required(
                          CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
                      ): str,
                  }
              ),
              errors=errors,
              description_placeholders=placeholders,
          )

      def _qr_code_schema(self) -> vol.Schema:
          """Build the ``scan`` form schema showing the current login QR code."""
          return vol.Schema(
              {
                  vol.Optional("QR"): QrCodeSelector(
                      config=QrCodeSelectorConfig(
                          data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
                          scale=5,
                          error_correction_level=QrErrorCorrectionLevel.QUARTILE,
                      )
                  )
              }
          )

      async def async_step_scan(
          self, user_input: dict[str, Any] | None = None
      ) -> FlowResult:
          """Show the login QR code, then check the login result on submit.

          On success the credentials are cached in
          ``hass.data[DOMAIN][DATA_STORE]["authentication"]`` and the user
          moves on to choosing a device. On failure a fresh QR code is
          requested and the form is shown again with the error.
          """
          if user_input is None:
              return self.async_show_form(
                  step_id="scan",
                  data_schema=self._qr_code_schema(),
              )

          ret, info = await self.hass.async_add_executor_job(
              self.__login_control.login_result,
              self.__qr_code,
              TUYA_CLIENT_ID,
              self.__user_code,
          )
          if not ret:
              # Try to get a new QR code on failure
              await self.__async_get_qr_code(self.__user_code)
              return self.async_show_form(
                  step_id="scan",
                  errors={"base": "login_error"},
                  data_schema=self._qr_code_schema(),
                  description_placeholders={
                      TUYA_RESPONSE_MSG: info.get(TUYA_RESPONSE_MSG, "Unknown error"),
                      TUYA_RESPONSE_CODE: info.get(TUYA_RESPONSE_CODE, "0"),
                  },
              )

          # Now that we have successfully logged in we can query for devices for the account.
          self.__authentication = {
              # Must be the user code entered in the cloud step, not the
              # terminal id: load_device_info passes both to Manager as
              # separate arguments.
              "user_code": self.__user_code,
              "terminal_id": info[CONF_TERMINAL_ID],
              "endpoint": info[CONF_ENDPOINT],
              "token_info": {
                  "t": info["t"],
                  "uid": info["uid"],
                  "expire_time": info["expire_time"],
                  "access_token": info["access_token"],
                  "refresh_token": info["refresh_token"],
              },
          }
          self.hass.data[DOMAIN][DATA_STORE]["authentication"] = self.__authentication
          # Do not dump hass.data here: it now contains access/refresh tokens.
          _LOGGER.debug(
              "Stored Tuya cloud authentication for endpoint %s",
              self.__authentication["endpoint"],
          )
          self.__cloud_devices = await self.load_device_info()
          return await self.async_step_choose_device(None)

      async def load_device_info(self) -> dict:
          """Fetch the account's devices from the Tuya cloud.

          Uses the credentials in ``self.__authentication``. Devices whose id
          or uuid is already a key in ``hass.data[DOMAIN]`` are treated as
          already registered and skipped.

          Returns:
              Mapping of device id -> dict of cloud device properties.
          """
          token_listener = TokenListener(self.hass)
          manager = Manager(
              TUYA_CLIENT_ID,
              self.__authentication["user_code"],
              self.__authentication["terminal_id"],
              self.__authentication["endpoint"],
              self.__authentication["token_info"],
              token_listener,
          )

          listener = DeviceListener(self.hass, manager)
          manager.add_device_listener(listener)

          # Get all devices from Tuya
          await self.hass.async_add_executor_job(manager.update_device_cache)

          # Register known device IDs
          cloud_devices = {}
          domain_data = self.hass.data.get(DOMAIN)
          for device in manager.device_map.values():
              cloud_device = {
                  # TODO - Use constants throughout
                  "category": device.category,
                  "id": device.id,
                  "ip": device.ip,  # This will be the WAN IP address so not usable.
                  CONF_LOCAL_KEY: device.local_key
                  if hasattr(device, CONF_LOCAL_KEY)
                  else "",
                  "name": device.name,
                  "node_id": device.node_id if hasattr(device, "node_id") else "",
                  "online": device.online,
                  "product_id": device.product_id,
                  "product_name": device.product_name,
                  "uid": device.uid,
                  "uuid": device.uuid,
                  "support_local": device.support_local,  # What does this mean?
                  CONF_DEVICE_CID: None,
                  "version": None,
              }
              _LOGGER.debug("Found device: %s", cloud_device)

              existing_id = domain_data.get(cloud_device["id"]) if domain_data else None
              existing_uuid = (
                  domain_data.get(cloud_device["uuid"]) if domain_data else None
              )
              if existing_id or existing_uuid:
                  _LOGGER.debug("Device is already registered.")
                  continue

              _LOGGER.debug("Adding device: %s", cloud_device["id"])
              cloud_devices[cloud_device["id"]] = cloud_device

          return cloud_devices

      async def async_step_choose_device(self, user_input=None):
          """Let the user pick a cloud device and, for sub-devices, its hub.

          Devices with a local key are offered as devices; those without one
          are offered as hubs. A device whose cloud ``ip`` is non-empty must
          be chosen with hub "None". A device with an empty ``ip`` needs a
          hub: the hub's id is used, with the child's uuid as the sub-device
          id and the child's local key. Aborts with ``no_devices`` when
          nothing is selectable.
          """
          errors = {}
          if user_input is not None:
              device_choice = self.__cloud_devices[user_input["device_id"]]
              hub_id = user_input["hub_id"]

              if device_choice["ip"] != "":
                  # This is a directly addable device.
                  if hub_id == "None":
                      # Copy so the cached cloud device list is left unchanged.
                      self.__cloud_device = {**device_choice, "ip": ""}
                      return await self.async_step_search(None)
                  # Show error if user selected a hub; the form is reshown below.
                  errors["base"] = "does_not_need_hub"
              elif hub_id != "None":
                  # Indirectly addressable device: go through the chosen hub,
                  # passing on the child's uuid and local key. Copy so the
                  # cached hub entry is left unchanged.
                  hub_choice = self.__cloud_devices[hub_id]
                  self.__cloud_device = {
                      **hub_choice,
                      "ip": "",
                      CONF_DEVICE_CID: device_choice["uuid"],
                      CONF_LOCAL_KEY: device_choice[CONF_LOCAL_KEY],
                  }
                  return await self.async_step_search(None)
              else:
                  # Show error if user did not select a hub; the form is reshown below.
                  errors["base"] = "needs_hub"

          device_list = [
              SelectOptionDict(
                  value=key,
                  label=(
                      f"{entry['name']} ({entry['product_name']})"
                      if entry["online"]
                      else f"{entry['name']} ({entry['product_name']}) OFFLINE"
                  ),
              )
              for key, entry in self.__cloud_devices.items()
              if entry[CONF_LOCAL_KEY] != ""
          ]
          _LOGGER.debug("Device count: %d", len(device_list))
          if not device_list:
              return self.async_abort(reason="no_devices")

          device_selector = SelectSelector(
              SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
          )

          hub_list = [SelectOptionDict(value="None", label="None")]
          hub_list.extend(
              SelectOptionDict(
                  value=key,
                  label=f"{entry['name']} ({entry['product_name']})",
              )
              for key, entry in self.__cloud_devices.items()
              if entry[CONF_LOCAL_KEY] == ""
          )
          _LOGGER.debug("Hub count: %d", len(hub_list) - 1)

          hub_selector = SelectSelector(
              SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
          )

          # Build form
          fields: OrderedDict[vol.Marker, Any] = OrderedDict()
          fields[vol.Required("device_id")] = device_selector
          fields[vol.Required("hub_id")] = hub_selector

          return self.async_show_form(
              step_id="choose_device",
              data_schema=vol.Schema(fields),
              errors=errors,
              last_step=False,
          )

      async def async_step_search(self, user_input=None):
          """Look up the chosen device's LAN IP address and protocol version.

          On submit the local network is scanned with ``scan_for_device`` in
          an executor. Any IP and version found are stored on the chosen
          device. The flow then continues to the local step, which uses them
          as defaults.
          """
          if user_input is not None:
              # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
              # This scan will take 18s with the default settings. If we cannot find the device we
              # will just leave the IP address blank and hope the user can discover the IP by other
              # means such as router device IP assignments.
              _LOGGER.debug(
                  "Scanning network to get IP address for %s.",
                  self.__cloud_device["id"],
              )
              self.__cloud_device["ip"] = ""
              local_device = await self.hass.async_add_executor_job(
                  scan_for_device, self.__cloud_device["id"]
              )
              if local_device["ip"] is not None:
                  _LOGGER.debug("Found: %s", local_device)
                  self.__cloud_device["ip"] = local_device["ip"]
                  self.__cloud_device["version"] = local_device["version"]
              else:
                  # Logger.warn is a deprecated alias of Logger.warning.
                  _LOGGER.warning("Could not find device: %s", self.__cloud_device["id"])
              return await self.async_step_local(None)
          return self.async_show_form(
              step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
          )

      async def async_step_local(self, user_input=None):
          """Collect or confirm the local connection settings and test them.

          Defaults are prefilled from the cloud device chosen earlier, if any.
          On submit a test connection is made. On success the settings become
          the entry data and the flow continues to type selection. Otherwise
          the form is shown again with the submitted values and a connection
          error.
          """
          errors = {}
          devid_opts = {}
          host_opts = {"default": ""}
          key_opts = {}
          proto_opts = {"default": 3.3}
          polling_opts = {"default": False}
          devcid_opts = {}

          if self.__cloud_device is not None:
              # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
              devid_opts = {"default": self.__cloud_device["id"]}
              host_opts = {"default": self.__cloud_device["ip"]}
              key_opts = {"default": self.__cloud_device[CONF_LOCAL_KEY]}
              if self.__cloud_device["version"] is not None:
                  proto_opts = {"default": float(self.__cloud_device["version"])}
              if self.__cloud_device[CONF_DEVICE_CID] is not None:
                  devcid_opts = {"default": self.__cloud_device[CONF_DEVICE_CID]}

          if user_input is not None:
              self.device = await async_test_connection(user_input, self.hass)
              if self.device:
                  self.data = user_input
                  return await self.async_step_select_type()
              errors["base"] = "connection"
              devid_opts["default"] = user_input[CONF_DEVICE_ID]
              host_opts["default"] = user_input[CONF_HOST]
              key_opts["default"] = user_input[CONF_LOCAL_KEY]
              if CONF_DEVICE_CID in user_input:
                  devcid_opts["default"] = user_input[CONF_DEVICE_CID]
              proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
              polling_opts["default"] = user_input[CONF_POLL_ONLY]

          return self.async_show_form(
              step_id="local",
              data_schema=vol.Schema(
                  {
                      vol.Required(CONF_DEVICE_ID, **devid_opts): str,
                      vol.Required(CONF_HOST, **host_opts): str,
                      vol.Required(CONF_LOCAL_KEY, **key_opts): str,
                      vol.Required(
                          CONF_PROTOCOL_VERSION,
                          **proto_opts,
                      ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
                      vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
                      vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
                  }
              ),
              errors=errors,
          )

      async def async_step_select_type(self, user_input=None):
          """Let the user pick the device type, preselecting the best match.

          Every type yielded by ``self.device.async_possible_types()`` is
          offered. The one whose ``match_quality`` against the cached DPS
          state is highest becomes the default. Aborts with ``not_supported``
          when no type is offered.
          """
          if user_input is not None:
              self.data[CONF_TYPE] = user_input[CONF_TYPE]
              return await self.async_step_choose_entities()

          # The cached state is read once and reused for matching and logging.
          dps = self.device._get_cached_state()
          types = []
          best_match = 0
          best_matching_type = None
          async for dev_type in self.device.async_possible_types():
              types.append(dev_type.config_type)
              q = dev_type.match_quality(dps)
              if q > best_match:
                  best_match = q
                  best_matching_type = dev_type.config_type

          best_match = int(best_match)
          # Warning level, paired with the "Report this" message below.
          _LOGGER.warning(
              "Device matches %s with quality of %d%%. DPS: %s",
              best_matching_type,
              best_match,
              log_json(dps),
          )
          _LOGGER.warning(
              "Report this to https://github.com/make-all/tuya-local/issues/",
          )
          if not types:
              return self.async_abort(reason="not_supported")
          return self.async_show_form(
              step_id="select_type",
              data_schema=vol.Schema(
                  {
                      vol.Required(
                          CONF_TYPE,
                          default=best_matching_type,
                      ): vol.In(types),
                  }
              ),
          )

      async def async_step_choose_entities(self, user_input=None):
          """Ask for the entry name and create the config entry.

          The submitted name becomes the entry title. Any other submitted
          fields are merged over the collected ``self.data``. Aborts with
          ``not_supported`` if no device config exists for the chosen type.
          """
          if user_input is not None:
              title = user_input[CONF_NAME]
              extra = {k: v for k, v in user_input.items() if k != CONF_NAME}
              return self.async_create_entry(
                  title=title, data={**self.data, **extra}
              )
          config = get_config(self.data[CONF_TYPE])
          if config is None:
              # Same outcome as the options flow for an unknown type.
              return self.async_abort(reason="not_supported")
          schema = {vol.Required(CONF_NAME, default=config.name): str}

          return self.async_show_form(
              step_id="choose_entities",
              data_schema=vol.Schema(schema),
          )

      @staticmethod
      @callback
      def async_get_options_flow(config_entry):
          """Return the options flow handler for an existing entry."""
          return OptionsFlowHandler(config_entry)

      async def __async_get_qr_code(self, user_code: str) -> tuple[bool, dict[str, Any]]:
          """Request a login QR code for ``user_code``.

          On success the user code and QR token are remembered for the scan
          step.

          Returns:
              ``(success, raw_response)``.
          """
          response = await self.hass.async_add_executor_job(
              self.__login_control.qr_code,
              TUYA_CLIENT_ID,
              TUYA_SCHEMA,
              user_code,
          )
          if success := response.get(TUYA_RESPONSE_SUCCESS, False):
              self.__user_code = user_code
              self.__qr_code = response[TUYA_RESPONSE_RESULT][TUYA_RESPONSE_QR_CODE]

          return success, response
  class OptionsFlowHandler(config_entries.OptionsFlow):
      """Options flow for changing the connection settings of an existing entry."""

      def __init__(self, config_entry):
          """Initialize options flow."""
          self.config_entry = config_entry

      async def async_step_init(self, user_input=None):
          """Entry point of the options flow; delegates to the user step."""
          return await self.async_step_user(user_input)

      async def async_step_user(self, user_input=None):
          """Manage the options.

          Current values are the entry's data overlaid with its options.
          Submitted values are saved as options only if a test connection
          using them succeeds; otherwise the form is shown again with a
          connection error. Aborts with ``not_supported`` if the entry's type
          has no device config.
          """
          current = {**self.config_entry.data, **self.config_entry.options}
          errors = {}

          if user_input is not None:
              current = {**current, **user_input}
              if await async_test_connection(current, self.hass):
                  return self.async_create_entry(title="", data=user_input)
              errors["base"] = "connection"

          if get_config(current[CONF_TYPE]) is None:
              return self.async_abort(reason="not_supported")

          protocol_choices = vol.In(["auto"] + API_PROTOCOL_VERSIONS)
          schema = {}
          schema[vol.Required(CONF_LOCAL_KEY, default=current.get(CONF_LOCAL_KEY, ""))] = str
          schema[vol.Required(CONF_HOST, default=current.get(CONF_HOST, ""))] = str
          schema[
              vol.Required(
                  CONF_PROTOCOL_VERSION,
                  default=current.get(CONF_PROTOCOL_VERSION, "auto"),
              )
          ] = protocol_choices
          schema[vol.Required(CONF_POLL_ONLY, default=current.get(CONF_POLL_ONLY, False))] = bool
          schema[vol.Optional(CONF_DEVICE_CID, default=current.get(CONF_DEVICE_CID, ""))] = str

          return self.async_show_form(
              step_id="user",
              data_schema=vol.Schema(schema),
              errors=errors,
          )
  async def async_test_connection(config: dict, hass: HomeAssistant):
      """Try to connect to a device using the given settings.

      If this integration already has a running device with the same id, it
      is paused for the test, with a 5 s wait before connecting. It is always
      resumed afterwards, including when the test raises or the calling task
      is cancelled.

      Args:
          config: Connection settings. Must contain CONF_DEVICE_ID, CONF_HOST,
              CONF_LOCAL_KEY and CONF_PROTOCOL_VERSION, and may contain
              CONF_DEVICE_CID for a sub-device.
          hass: Home Assistant instance.

      Returns:
          The test ``TuyaLocalDevice`` if it returned state, otherwise None.
      """
      domain_data = hass.data.get(DOMAIN)
      existing = domain_data.get(get_device_id(config)) if domain_data else None
      if existing:
          _LOGGER.info("Pausing existing device to test new connection parameters")
          existing["device"].pause()

      retval = None
      try:
          if existing:
              # Presumably gives the paused device time to release its
              # connection before the test connects.
              await asyncio.sleep(5)
          subdevice_id = config.get(CONF_DEVICE_CID)
          device = TuyaLocalDevice(
              "Test",
              config[CONF_DEVICE_ID],
              config[CONF_HOST],
              config[CONF_LOCAL_KEY],
              config[CONF_PROTOCOL_VERSION],
              subdevice_id,
              hass,
              True,
          )
          await device.async_refresh()
          retval = device if device.has_returned_state else None
      except Exception as e:
          _LOGGER.warning("Connection test failed with %s %s", type(e), e)
          retval = None
      finally:
          # Resume in finally: CancelledError is not an Exception subclass, so
          # resuming only after the try/except could leave the device paused.
          if existing:
              _LOGGER.info("Restarting device after test")
              existing["device"].resume()

      return retval
  def scan_for_device(id):
      """Search the local network for the Tuya device with the given id.

      Thin wrapper around ``tinytuya.find_device``. The scan blocks for a
      while, which is why the config flow runs it via
      ``hass.async_add_executor_job``.
      """
      found = tinytuya.find_device(dev_id=id)
      return found
  class DeviceListener(SharingDeviceListener):
      """Device listener for the cloud ``Manager``; it only logs notifications."""

      def __init__(
          self,
          hass: HomeAssistant,
          manager: Manager,
      ) -> None:
          """Init DeviceListener."""
          self.hass = hass
          self.manager = manager

      def update_device(self, device: CustomerDevice) -> None:
          """Log a device status update."""
          _LOGGER.debug(
              "Received update for device %s: %s",
              device.id,
              self.manager.device_map[device.id].status,
          )

      def add_device(self, device: CustomerDevice) -> None:
          """Log that a device was added."""
          _LOGGER.debug(
              "Received add device %s: %s",
              device.id,
              self.manager.device_map[device.id].status,
          )

      def remove_device(self, device_id: str) -> None:
          """Log that a device was removed.

          The device may no longer be in ``manager.device_map`` when this is
          called, so it is looked up with ``.get`` instead of raising
          KeyError from a debug log.
          """
          removed = self.manager.device_map.get(device_id)
          _LOGGER.debug(
              "Received remove device %s: %s",
              device_id,
              removed.status if removed is not None else None,
          )
  class TokenListener(SharingTokenListener):
      """Receives token refresh notifications from the Tuya sharing manager.

      NOTE(review): refreshed tokens are only logged here. They are not
      written back to the cached authentication or to any config entry.
      """

      def __init__(self, hass: HomeAssistant) -> None:
          """Keep a reference to the Home Assistant instance."""
          self.hass = hass

      def update_token(self, token_info: dict[str, Any]) -> None:
          """Handle a token refresh; currently this only emits a debug log."""
          message = "update_token"
          _LOGGER.debug(message)