config_flow.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. import asyncio
  2. import logging
  3. from collections import OrderedDict
  4. from typing import Any
  5. import tinytuya
  6. import voluptuous as vol
  7. from homeassistant.config_entries import (
  8. CONN_CLASS_LOCAL_PUSH,
  9. ConfigEntry,
  10. ConfigFlow,
  11. OptionsFlow,
  12. )
  13. from homeassistant.const import CONF_HOST, CONF_NAME
  14. from homeassistant.core import HomeAssistant, callback
  15. from homeassistant.data_entry_flow import FlowResult
  16. from homeassistant.helpers.selector import (
  17. QrCodeSelector,
  18. QrCodeSelectorConfig,
  19. QrErrorCorrectionLevel,
  20. SelectOptionDict,
  21. SelectSelector,
  22. SelectSelectorConfig,
  23. SelectSelectorMode,
  24. )
  25. from . import DOMAIN
  26. from .cloud import Cloud
  27. from .const import (
  28. API_PROTOCOL_VERSIONS,
  29. CONF_DEVICE_CID,
  30. CONF_DEVICE_ID,
  31. CONF_LOCAL_KEY,
  32. CONF_POLL_ONLY,
  33. CONF_PROTOCOL_VERSION,
  34. CONF_TYPE,
  35. CONF_USER_CODE,
  36. DATA_STORE,
  37. )
  38. from .device import TuyaLocalDevice
  39. from .helpers.config import get_device_id
  40. from .helpers.device_config import get_config
  41. from .helpers.log import log_json
  42. _LOGGER = logging.getLogger(__name__)
  43. class ConfigFlowHandler(ConfigFlow, domain=DOMAIN):
  44. VERSION = 13
  45. MINOR_VERSION = 9
  46. CONNECTION_CLASS = CONN_CLASS_LOCAL_PUSH
  47. device = None
  48. data = {}
  49. __qr_code: str | None = None
  50. __cloud_devices: dict[str, Any] = {}
  51. __cloud_device: dict[str, Any] | None = None
  52. def __init__(self) -> None:
  53. """Initialize the config flow."""
  54. self.cloud = None
  55. def init_cloud(self):
  56. if self.cloud is None:
  57. self.cloud = Cloud(self.hass)
  58. async def async_step_user(self, user_input=None):
  59. errors = {}
  60. if self.hass.data.get(DOMAIN) is None:
  61. self.hass.data[DOMAIN] = {}
  62. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  63. self.hass.data[DOMAIN][DATA_STORE] = {}
  64. if user_input is not None:
  65. if user_input["setup_mode"] == "cloud":
  66. self.init_cloud()
  67. try:
  68. if self.cloud.is_authenticated:
  69. self.__cloud_devices = await self.cloud.async_get_devices()
  70. return await self.async_step_choose_device()
  71. except Exception as e:
  72. # Re-authentication is needed.
  73. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  74. _LOGGER.warning("Re-authentication is required.")
  75. return await self.async_step_cloud()
  76. if user_input["setup_mode"] == "manual":
  77. return await self.async_step_local()
  78. # Build form
  79. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  80. fields[vol.Required("setup_mode")] = SelectSelector(
  81. SelectSelectorConfig(
  82. options=["cloud", "manual"],
  83. mode=SelectSelectorMode.LIST,
  84. translation_key="setup_mode",
  85. )
  86. )
  87. return self.async_show_form(
  88. step_id="user",
  89. data_schema=vol.Schema(fields),
  90. errors=errors or {},
  91. last_step=False,
  92. )
  93. async def async_step_cloud(
  94. self, user_input: dict[str, Any] | None = None
  95. ) -> FlowResult:
  96. """Step user."""
  97. errors = {}
  98. placeholders = {}
  99. self.init_cloud()
  100. if user_input is not None:
  101. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  102. if response:
  103. self.__qr_code = response
  104. return await self.async_step_scan()
  105. errors["base"] = "login_error"
  106. placeholders = self.cloud.last_error
  107. else:
  108. user_input = {}
  109. return self.async_show_form(
  110. step_id="cloud",
  111. data_schema=vol.Schema(
  112. {
  113. vol.Required(
  114. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  115. ): str,
  116. }
  117. ),
  118. errors=errors,
  119. description_placeholders=placeholders,
  120. )
  121. async def async_step_scan(
  122. self, user_input: dict[str, Any] | None = None
  123. ) -> FlowResult:
  124. """Step scan."""
  125. if user_input is None:
  126. return self.async_show_form(
  127. step_id="scan",
  128. data_schema=vol.Schema(
  129. {
  130. vol.Optional("QR"): QrCodeSelector(
  131. config=QrCodeSelectorConfig(
  132. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  133. scale=5,
  134. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  135. )
  136. )
  137. }
  138. ),
  139. )
  140. self.init_cloud()
  141. if not await self.cloud.async_login():
  142. # Try to get a new QR code on failure
  143. response = await self.cloud.async_get_qr_code()
  144. errors = {"base": "login_error"}
  145. placeholders = self.cloud.last_error
  146. if response:
  147. self.__qr_code = response
  148. return self.async_show_form(
  149. step_id="scan",
  150. errors=errors,
  151. data_schema=vol.Schema(
  152. {
  153. vol.Optional("QR"): QrCodeSelector(
  154. config=QrCodeSelectorConfig(
  155. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  156. scale=5,
  157. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  158. )
  159. )
  160. }
  161. ),
  162. description_placeholders=placeholders,
  163. )
  164. self.__cloud_devices = await self.cloud.async_get_devices()
  165. return await self.async_step_choose_device()
  166. async def async_step_choose_device(self, user_input=None):
  167. errors = {}
  168. if user_input is not None:
  169. device_choice = self.__cloud_devices[user_input["device_id"]]
  170. if device_choice["ip"] != "":
  171. # This is a directly addable device.
  172. if user_input["hub_id"] == "None":
  173. device_choice["ip"] = ""
  174. self.__cloud_device = device_choice
  175. return await self.async_step_search()
  176. else:
  177. # Show error if user selected a hub.
  178. errors["base"] = "does_not_need_hub"
  179. # Fall through to reshow the form.
  180. else:
  181. # This is an indirectly addressable device. Need to know which hub it is connected to.
  182. if user_input["hub_id"] != "None":
  183. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  184. # Populate node_id or uuid and local_key from the child
  185. # device to pass on complete information to the local step.
  186. hub_choice["ip"] = ""
  187. hub_choice[CONF_DEVICE_CID] = (
  188. device_choice["node_id"] or device_choice["uuid"]
  189. )
  190. if device_choice.get(CONF_LOCAL_KEY):
  191. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  192. # Communicate the sub device product id to help match the
  193. # correect device config in the next step.
  194. hub_choice["product_id"] = device_choice["product_id"]
  195. self.__cloud_device = hub_choice
  196. return await self.async_step_search()
  197. else:
  198. # Show error if user did not select a hub.
  199. errors["base"] = "needs_hub"
  200. # Fall through to reshow the form.
  201. device_list = []
  202. for key in self.__cloud_devices.keys():
  203. device_entry = self.__cloud_devices[key]
  204. if device_entry.get("exists"):
  205. continue
  206. if device_entry[CONF_LOCAL_KEY] != "":
  207. if device_entry["online"]:
  208. device_list.append(
  209. SelectOptionDict(
  210. value=key,
  211. label=f"{device_entry['name']} ({device_entry['product_name']})",
  212. )
  213. )
  214. else:
  215. device_list.append(
  216. SelectOptionDict(
  217. value=key,
  218. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  219. )
  220. )
  221. _LOGGER.debug(f"Device count: {len(device_list)}")
  222. if len(device_list) == 0:
  223. return self.async_abort(reason="no_devices")
  224. device_selector = SelectSelector(
  225. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  226. )
  227. hub_list = []
  228. hub_list.append(SelectOptionDict(value="None", label="None"))
  229. for key in self.__cloud_devices.keys():
  230. hub_entry = self.__cloud_devices[key]
  231. if hub_entry["is_hub"]:
  232. hub_list.append(
  233. SelectOptionDict(
  234. value=key,
  235. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  236. )
  237. )
  238. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  239. hub_selector = SelectSelector(
  240. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  241. )
  242. # Build form
  243. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  244. fields[vol.Required("device_id")] = device_selector
  245. fields[vol.Required("hub_id")] = hub_selector
  246. return self.async_show_form(
  247. step_id="choose_device",
  248. data_schema=vol.Schema(fields),
  249. errors=errors or {},
  250. last_step=False,
  251. )
  252. async def async_step_search(self, user_input=None):
  253. if user_input is not None:
  254. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  255. # This scan will take 18s with the default settings. If we cannot find the device we
  256. # will just leave the IP address blank and hope the user can discover the IP by other
  257. # means such as router device IP assignments.
  258. _LOGGER.debug(
  259. f"Scanning network to get IP address for {self.__cloud_device.get('id', 'DEVICE_KEY_UNAVAILABLE')}."
  260. )
  261. self.__cloud_device["ip"] = ""
  262. try:
  263. local_device = await self.hass.async_add_executor_job(
  264. scan_for_device, self.__cloud_device.get("id")
  265. )
  266. except OSError:
  267. local_device = {"ip": None, "version": ""}
  268. if local_device.get("ip"):
  269. _LOGGER.debug(f"Found: {local_device}")
  270. self.__cloud_device["ip"] = local_device.get("ip")
  271. self.__cloud_device["version"] = local_device.get("version")
  272. if not self.__cloud_device.get(CONF_DEVICE_CID):
  273. self.__cloud_device["local_product_id"] = local_device.get(
  274. "productKey"
  275. )
  276. else:
  277. _LOGGER.warning(
  278. f"Could not find device: {self.__cloud_device.get('id', 'DEVICE_KEY_UNAVAILABLE')}"
  279. )
  280. return await self.async_step_local()
  281. return self.async_show_form(
  282. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  283. )
  284. async def async_step_local(self, user_input=None):
  285. errors = {}
  286. devid_opts = {}
  287. host_opts = {"default": ""}
  288. key_opts = {}
  289. proto_opts = {"default": 3.3}
  290. polling_opts = {"default": False}
  291. devcid_opts = {}
  292. if self.__cloud_device is not None:
  293. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  294. devid_opts = {"default": self.__cloud_device.get("id")}
  295. host_opts = {"default": self.__cloud_device.get("ip")}
  296. key_opts = {"default": self.__cloud_device.get(CONF_LOCAL_KEY)}
  297. if self.__cloud_device.get("version"):
  298. proto_opts = {"default": float(self.__cloud_device.get("version"))}
  299. if self.__cloud_device.get(CONF_DEVICE_CID):
  300. devcid_opts = {"default": self.__cloud_device.get(CONF_DEVICE_CID)}
  301. if user_input is not None:
  302. self.device = await async_test_connection(user_input, self.hass)
  303. if self.device:
  304. self.data = user_input
  305. if self.__cloud_device:
  306. if self.__cloud_device.get("product_id"):
  307. self.device.set_detected_product_id(
  308. self.__cloud_device.get("product_id")
  309. )
  310. if self.__cloud_device.get("local_product_id"):
  311. self.device.set_detected_product_id(
  312. self.__cloud_device.get("local_product_id")
  313. )
  314. return await self.async_step_select_type()
  315. else:
  316. errors["base"] = "connection"
  317. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  318. host_opts["default"] = user_input[CONF_HOST]
  319. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  320. if CONF_DEVICE_CID in user_input:
  321. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  322. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  323. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  324. return self.async_show_form(
  325. step_id="local",
  326. data_schema=vol.Schema(
  327. {
  328. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  329. vol.Required(CONF_HOST, **host_opts): str,
  330. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  331. vol.Required(
  332. CONF_PROTOCOL_VERSION,
  333. **proto_opts,
  334. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  335. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  336. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  337. }
  338. ),
  339. errors=errors,
  340. )
  341. async def async_step_select_type(self, user_input=None):
  342. if user_input is not None:
  343. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  344. return await self.async_step_choose_entities()
  345. types = []
  346. best_match = 0
  347. best_matching_type = None
  348. for type in await self.device.async_possible_types():
  349. types.append(type.config_type)
  350. q = type.match_quality(
  351. self.device._get_cached_state(),
  352. self.device._product_ids,
  353. )
  354. if q > best_match:
  355. best_match = q
  356. best_matching_type = type.config_type
  357. best_match = int(best_match)
  358. dps = self.device._get_cached_state()
  359. if self.__cloud_device:
  360. _LOGGER.warning(
  361. "Adding %s device with product id %s",
  362. self.__cloud_device.get("product_name", "UNKNOWN"),
  363. self.__cloud_device.get("product_id", "UNKNOWN"),
  364. )
  365. if self.__cloud_device.get("local_product_id") and self.__cloud_device.get(
  366. "local_product_id"
  367. ) != self.__cloud_device.get("product_id"):
  368. _LOGGER.warning(
  369. "Local product id differs from cloud: %s",
  370. self.__cloud_device.get("local_product_id"),
  371. )
  372. try:
  373. self.init_cloud()
  374. model = await self.cloud.async_get_datamodel(
  375. self.__cloud_device.get("id"),
  376. )
  377. if model:
  378. _LOGGER.warning(
  379. "Cloud device spec:\n%s",
  380. log_json(model),
  381. )
  382. except Exception as e:
  383. _LOGGER.warning("Unable to fetch data model from cloud: %s", e)
  384. _LOGGER.warning(
  385. "Device matches %s with quality of %d%%. DPS: %s",
  386. best_matching_type,
  387. best_match,
  388. log_json(dps),
  389. )
  390. _LOGGER.warning(
  391. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  392. )
  393. if types:
  394. return self.async_show_form(
  395. step_id="select_type",
  396. data_schema=vol.Schema(
  397. {
  398. vol.Required(
  399. CONF_TYPE,
  400. default=best_matching_type,
  401. ): vol.In(types),
  402. }
  403. ),
  404. )
  405. else:
  406. return self.async_abort(reason="not_supported")
  407. async def async_step_choose_entities(self, user_input=None):
  408. if user_input is not None:
  409. title = user_input[CONF_NAME]
  410. del user_input[CONF_NAME]
  411. return self.async_create_entry(
  412. title=title, data={**self.data, **user_input}
  413. )
  414. config = await self.hass.async_add_executor_job(
  415. get_config,
  416. self.data[CONF_TYPE],
  417. )
  418. schema = {vol.Required(CONF_NAME, default=config.name): str}
  419. return self.async_show_form(
  420. step_id="choose_entities",
  421. data_schema=vol.Schema(schema),
  422. )
  423. @staticmethod
  424. @callback
  425. def async_get_options_flow(config_entry: ConfigEntry):
  426. return OptionsFlowHandler()
  427. class OptionsFlowHandler(OptionsFlow):
  428. def __init__(self):
  429. """Initialize options flow."""
  430. pass
  431. async def async_step_init(self, user_input=None):
  432. return await self.async_step_user(user_input)
  433. async def async_step_user(self, user_input=None):
  434. """Manage the options."""
  435. errors = {}
  436. config = {**self.config_entry.data, **self.config_entry.options}
  437. if user_input is not None:
  438. config = {**config, **user_input}
  439. device = await async_test_connection(config, self.hass)
  440. if device:
  441. return self.async_create_entry(title="", data=user_input)
  442. else:
  443. errors["base"] = "connection"
  444. schema = {
  445. vol.Required(
  446. CONF_LOCAL_KEY,
  447. default=config.get(CONF_LOCAL_KEY, ""),
  448. ): str,
  449. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  450. vol.Required(
  451. CONF_PROTOCOL_VERSION,
  452. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  453. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  454. vol.Required(
  455. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  456. ): bool,
  457. }
  458. cfg = await self.hass.async_add_executor_job(
  459. get_config,
  460. config[CONF_TYPE],
  461. )
  462. if cfg is None:
  463. return self.async_abort(reason="not_supported")
  464. return self.async_show_form(
  465. step_id="user",
  466. data_schema=vol.Schema(schema),
  467. errors=errors,
  468. )
  469. def create_test_device(hass: HomeAssistant, config: dict):
  470. """Set up a tuya device based on passed in config."""
  471. subdevice_id = config.get(CONF_DEVICE_CID)
  472. device = TuyaLocalDevice(
  473. "Test",
  474. config[CONF_DEVICE_ID],
  475. config[CONF_HOST],
  476. config[CONF_LOCAL_KEY],
  477. config[CONF_PROTOCOL_VERSION],
  478. subdevice_id,
  479. hass,
  480. True,
  481. )
  482. return device
  483. async def async_test_connection(config: dict, hass: HomeAssistant):
  484. domain_data = hass.data.get(DOMAIN)
  485. existing = domain_data.get(get_device_id(config)) if domain_data else None
  486. if existing and existing.get("device"):
  487. _LOGGER.info("Pausing existing device to test new connection parameters")
  488. existing["device"].pause()
  489. await asyncio.sleep(5)
  490. try:
  491. device = await hass.async_add_executor_job(
  492. create_test_device,
  493. hass,
  494. config,
  495. )
  496. await device.async_refresh()
  497. retval = device if device.has_returned_state else None
  498. except Exception as e:
  499. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  500. retval = None
  501. if existing and existing.get("device"):
  502. _LOGGER.info("Restarting device after test")
  503. existing["device"].resume()
  504. return retval
  505. def scan_for_device(id):
  506. return tinytuya.find_device(dev_id=id)