config_flow.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. import asyncio
  2. import logging
  3. from collections import OrderedDict
  4. from typing import Any
  5. import tinytuya
  6. import voluptuous as vol
  7. from homeassistant import config_entries
  8. from homeassistant.const import CONF_HOST, CONF_NAME
  9. from homeassistant.core import HomeAssistant, callback
  10. from homeassistant.data_entry_flow import FlowResult
  11. from homeassistant.helpers.selector import (
  12. QrCodeSelector,
  13. QrCodeSelectorConfig,
  14. QrErrorCorrectionLevel,
  15. SelectOptionDict,
  16. SelectSelector,
  17. SelectSelectorConfig,
  18. SelectSelectorMode,
  19. )
  20. from . import DOMAIN
  21. from .cloud import Cloud
  22. from .const import (
  23. API_PROTOCOL_VERSIONS,
  24. CONF_DEVICE_CID,
  25. CONF_DEVICE_ID,
  26. CONF_LOCAL_KEY,
  27. CONF_POLL_ONLY,
  28. CONF_PROTOCOL_VERSION,
  29. CONF_TYPE,
  30. CONF_USER_CODE,
  31. DATA_STORE,
  32. )
  33. from .device import TuyaLocalDevice
  34. from .helpers.config import get_device_id
  35. from .helpers.device_config import get_config
  36. from .helpers.log import log_json
# Module-level logger used by the flows and helpers in this file.
_LOGGER = logging.getLogger(__name__)
  38. class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
  39. VERSION = 13
  40. MINOR_VERSION = 7
  41. CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
  42. device = None
  43. data = {}
  44. __qr_code: str | None = None
  45. __cloud_devices: dict[str, Any] = {}
  46. __cloud_device: dict[str, Any] | None = None
  47. def __init__(self) -> None:
  48. """Initialize the config flow."""
  49. self.cloud = None
  50. def init_cloud(self):
  51. if self.cloud is None:
  52. self.cloud = Cloud(self.hass)
  53. async def async_step_user(self, user_input=None):
  54. errors = {}
  55. if self.hass.data.get(DOMAIN) is None:
  56. self.hass.data[DOMAIN] = {}
  57. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  58. self.hass.data[DOMAIN][DATA_STORE] = {}
  59. if user_input is not None:
  60. if user_input["setup_mode"] == "cloud":
  61. self.init_cloud()
  62. try:
  63. if self.cloud.is_authenticated:
  64. self.__cloud_devices = await self.cloud.async_get_devices()
  65. return await self.async_step_choose_device()
  66. except Exception as e:
  67. # Re-authentication is needed.
  68. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  69. _LOGGER.warning("Re-authentication is required.")
  70. return await self.async_step_cloud()
  71. if user_input["setup_mode"] == "manual":
  72. return await self.async_step_local()
  73. # Build form
  74. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  75. fields[vol.Required("setup_mode")] = SelectSelector(
  76. SelectSelectorConfig(
  77. options=["cloud", "manual"],
  78. mode=SelectSelectorMode.LIST,
  79. translation_key="setup_mode",
  80. )
  81. )
  82. return self.async_show_form(
  83. step_id="user",
  84. data_schema=vol.Schema(fields),
  85. errors=errors or {},
  86. last_step=False,
  87. )
  88. async def async_step_cloud(
  89. self, user_input: dict[str, Any] | None = None
  90. ) -> FlowResult:
  91. """Step user."""
  92. errors = {}
  93. placeholders = {}
  94. self.init_cloud()
  95. if user_input is not None:
  96. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  97. if response:
  98. self.__qr_code = response
  99. return await self.async_step_scan()
  100. errors["base"] = "login_error"
  101. placeholders = self.cloud.last_error
  102. else:
  103. user_input = {}
  104. return self.async_show_form(
  105. step_id="cloud",
  106. data_schema=vol.Schema(
  107. {
  108. vol.Required(
  109. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  110. ): str,
  111. }
  112. ),
  113. errors=errors,
  114. description_placeholders=placeholders,
  115. )
  116. async def async_step_scan(
  117. self, user_input: dict[str, Any] | None = None
  118. ) -> FlowResult:
  119. """Step scan."""
  120. if user_input is None:
  121. return self.async_show_form(
  122. step_id="scan",
  123. data_schema=vol.Schema(
  124. {
  125. vol.Optional("QR"): QrCodeSelector(
  126. config=QrCodeSelectorConfig(
  127. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  128. scale=5,
  129. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  130. )
  131. )
  132. }
  133. ),
  134. )
  135. self.init_cloud()
  136. if not await self.cloud.async_login():
  137. # Try to get a new QR code on failure
  138. response = await self.cloud.async_get_qr_code()
  139. errors = {"base": "login_error"}
  140. placeholders = self.cloud.last_error
  141. if response:
  142. self.__qr_code = response
  143. return self.async_show_form(
  144. step_id="scan",
  145. errors=errors,
  146. data_schema=vol.Schema(
  147. {
  148. vol.Optional("QR"): QrCodeSelector(
  149. config=QrCodeSelectorConfig(
  150. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  151. scale=5,
  152. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  153. )
  154. )
  155. }
  156. ),
  157. description_placeholders=placeholders,
  158. )
  159. self.__cloud_devices = await self.cloud.async_get_devices()
  160. return await self.async_step_choose_device()
  161. async def async_step_choose_device(self, user_input=None):
  162. errors = {}
  163. if user_input is not None:
  164. device_choice = self.__cloud_devices[user_input["device_id"]]
  165. if device_choice["ip"] != "":
  166. # This is a directly addable device.
  167. if user_input["hub_id"] == "None":
  168. device_choice["ip"] = ""
  169. self.__cloud_device = device_choice
  170. return await self.async_step_search()
  171. else:
  172. # Show error if user selected a hub.
  173. errors["base"] = "does_not_need_hub"
  174. # Fall through to reshow the form.
  175. else:
  176. # This is an indirectly addressable device. Need to know which hub it is connected to.
  177. if user_input["hub_id"] != "None":
  178. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  179. # Populate node_id or uuid and local_key from the child
  180. # device to pass on complete information to the local step.
  181. hub_choice["ip"] = ""
  182. hub_choice[CONF_DEVICE_CID] = (
  183. device_choice["node_id"] or device_choice["uuid"]
  184. )
  185. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  186. self.__cloud_device = hub_choice
  187. return await self.async_step_search()
  188. else:
  189. # Show error if user did not select a hub.
  190. errors["base"] = "needs_hub"
  191. # Fall through to reshow the form.
  192. device_list = []
  193. for key in self.__cloud_devices.keys():
  194. device_entry = self.__cloud_devices[key]
  195. if device_entry.get("exists"):
  196. continue
  197. if device_entry[CONF_LOCAL_KEY] != "":
  198. if device_entry["online"]:
  199. device_list.append(
  200. SelectOptionDict(
  201. value=key,
  202. label=f"{device_entry['name']} ({device_entry['product_name']})",
  203. )
  204. )
  205. else:
  206. device_list.append(
  207. SelectOptionDict(
  208. value=key,
  209. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  210. )
  211. )
  212. _LOGGER.debug(f"Device count: {len(device_list)}")
  213. if len(device_list) == 0:
  214. return self.async_abort(reason="no_devices")
  215. device_selector = SelectSelector(
  216. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  217. )
  218. hub_list = []
  219. hub_list.append(SelectOptionDict(value="None", label="None"))
  220. for key in self.__cloud_devices.keys():
  221. hub_entry = self.__cloud_devices[key]
  222. if hub_entry["is_hub"]:
  223. hub_list.append(
  224. SelectOptionDict(
  225. value=key,
  226. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  227. )
  228. )
  229. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  230. hub_selector = SelectSelector(
  231. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  232. )
  233. # Build form
  234. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  235. fields[vol.Required("device_id")] = device_selector
  236. fields[vol.Required("hub_id")] = hub_selector
  237. return self.async_show_form(
  238. step_id="choose_device",
  239. data_schema=vol.Schema(fields),
  240. errors=errors or {},
  241. last_step=False,
  242. )
  243. async def async_step_search(self, user_input=None):
  244. if user_input is not None:
  245. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  246. # This scan will take 18s with the default settings. If we cannot find the device we
  247. # will just leave the IP address blank and hope the user can discover the IP by other
  248. # means such as router device IP assignments.
  249. _LOGGER.debug(
  250. f"Scanning network to get IP address for {self.__cloud_device['id']}."
  251. )
  252. self.__cloud_device["ip"] = ""
  253. try:
  254. local_device = await self.hass.async_add_executor_job(
  255. scan_for_device, self.__cloud_device["id"]
  256. )
  257. except OSError:
  258. local_device = {"ip": None, "version": ""}
  259. if local_device.get("ip"):
  260. _LOGGER.debug(f"Found: {local_device}")
  261. self.__cloud_device["ip"] = local_device.get("ip")
  262. self.__cloud_device["version"] = local_device.get("version")
  263. self.__cloud_device["local_product_id"] = local_device.get("productKey")
  264. else:
  265. _LOGGER.warning(f"Could not find device: {self.__cloud_device['id']}")
  266. return await self.async_step_local()
  267. return self.async_show_form(
  268. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  269. )
  270. async def async_step_local(self, user_input=None):
  271. errors = {}
  272. devid_opts = {}
  273. host_opts = {"default": ""}
  274. key_opts = {}
  275. proto_opts = {"default": 3.3}
  276. polling_opts = {"default": False}
  277. devcid_opts = {}
  278. if self.__cloud_device is not None:
  279. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  280. devid_opts = {"default": self.__cloud_device["id"]}
  281. host_opts = {"default": self.__cloud_device["ip"]}
  282. key_opts = {"default": self.__cloud_device[CONF_LOCAL_KEY]}
  283. if self.__cloud_device["version"] is not None:
  284. proto_opts = {"default": float(self.__cloud_device["version"])}
  285. if self.__cloud_device[CONF_DEVICE_CID] is not None:
  286. devcid_opts = {"default": self.__cloud_device[CONF_DEVICE_CID]}
  287. if user_input is not None:
  288. self.device = await async_test_connection(user_input, self.hass)
  289. if self.device:
  290. self.data = user_input
  291. if self.__cloud_device:
  292. if self.__cloud_device.get("product_id"):
  293. self.device.set_detected_product_id(
  294. self.__cloud_device["product_id"]
  295. )
  296. if self.__cloud_device.get("local_product_id"):
  297. self.device.set_detected_product_id(
  298. self.__cloud_device["local_product_id"]
  299. )
  300. return await self.async_step_select_type()
  301. else:
  302. errors["base"] = "connection"
  303. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  304. host_opts["default"] = user_input[CONF_HOST]
  305. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  306. if CONF_DEVICE_CID in user_input:
  307. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  308. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  309. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  310. return self.async_show_form(
  311. step_id="local",
  312. data_schema=vol.Schema(
  313. {
  314. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  315. vol.Required(CONF_HOST, **host_opts): str,
  316. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  317. vol.Required(
  318. CONF_PROTOCOL_VERSION,
  319. **proto_opts,
  320. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  321. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  322. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  323. }
  324. ),
  325. errors=errors,
  326. )
  327. async def async_step_select_type(self, user_input=None):
  328. if user_input is not None:
  329. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  330. return await self.async_step_choose_entities()
  331. types = []
  332. best_match = 0
  333. best_matching_type = None
  334. async for type in self.device.async_possible_types():
  335. types.append(type.config_type)
  336. q = type.match_quality(
  337. self.device._get_cached_state(),
  338. self.device._product_ids,
  339. )
  340. if q > best_match:
  341. best_match = q
  342. best_matching_type = type.config_type
  343. best_match = int(best_match)
  344. dps = self.device._get_cached_state()
  345. if self.__cloud_device:
  346. _LOGGER.warning(
  347. "Adding %s device with product id %s",
  348. self.__cloud_device["product_name"],
  349. self.__cloud_device["product_id"],
  350. )
  351. if (
  352. self.__cloud_device["local_product_id"]
  353. and self.__cloud_device["local_product_id"]
  354. != self.__cloud_device["product_id"]
  355. ):
  356. _LOGGER.warning(
  357. "Local product id differs from cloud: %s",
  358. self.__cloud_device["local_product_id"],
  359. )
  360. # try:
  361. # self.init_cloud()
  362. # model = await self.cloud.async_get_datamodel(
  363. # self.__cloud_device["id"],
  364. # )
  365. # if model:
  366. # _LOGGER.warning(
  367. # "Device specficication:\n%s",
  368. # json.dumps(model, indent=4),
  369. # )
  370. # except Exception as e:
  371. # _LOGGER.warning("Unable to fetch data model from cloud: %s", e)
  372. _LOGGER.warning(
  373. "Device matches %s with quality of %d%%. DPS: %s",
  374. best_matching_type,
  375. best_match,
  376. log_json(dps),
  377. )
  378. _LOGGER.warning(
  379. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  380. )
  381. if types:
  382. return self.async_show_form(
  383. step_id="select_type",
  384. data_schema=vol.Schema(
  385. {
  386. vol.Required(
  387. CONF_TYPE,
  388. default=best_matching_type,
  389. ): vol.In(types),
  390. }
  391. ),
  392. )
  393. else:
  394. return self.async_abort(reason="not_supported")
  395. async def async_step_choose_entities(self, user_input=None):
  396. if user_input is not None:
  397. title = user_input[CONF_NAME]
  398. del user_input[CONF_NAME]
  399. return self.async_create_entry(
  400. title=title, data={**self.data, **user_input}
  401. )
  402. config = await self.hass.async_add_executor_job(
  403. get_config,
  404. self.data[CONF_TYPE],
  405. )
  406. schema = {vol.Required(CONF_NAME, default=config.name): str}
  407. return self.async_show_form(
  408. step_id="choose_entities",
  409. data_schema=vol.Schema(schema),
  410. )
  411. @staticmethod
  412. @callback
  413. def async_get_options_flow(config_entry):
  414. return OptionsFlowHandler(config_entry)
  415. class OptionsFlowHandler(config_entries.OptionsFlow):
  416. def __init__(self, config_entry):
  417. """Initialize options flow."""
  418. self.config_entry = config_entry
  419. async def async_step_init(self, user_input=None):
  420. return await self.async_step_user(user_input)
  421. async def async_step_user(self, user_input=None):
  422. """Manage the options."""
  423. errors = {}
  424. config = {**self.config_entry.data, **self.config_entry.options}
  425. if user_input is not None:
  426. config = {**config, **user_input}
  427. device = await async_test_connection(config, self.hass)
  428. if device:
  429. return self.async_create_entry(title="", data=user_input)
  430. else:
  431. errors["base"] = "connection"
  432. schema = {
  433. vol.Required(
  434. CONF_LOCAL_KEY,
  435. default=config.get(CONF_LOCAL_KEY, ""),
  436. ): str,
  437. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  438. vol.Required(
  439. CONF_PROTOCOL_VERSION,
  440. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  441. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  442. vol.Required(
  443. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  444. ): bool,
  445. }
  446. cfg = await self.hass.async_add_executor_job(
  447. get_config,
  448. config[CONF_TYPE],
  449. )
  450. if cfg is None:
  451. return self.async_abort(reason="not_supported")
  452. return self.async_show_form(
  453. step_id="user",
  454. data_schema=vol.Schema(schema),
  455. errors=errors,
  456. )
  457. def create_test_device(hass: HomeAssistant, config: dict):
  458. """Set up a tuya device based on passed in config."""
  459. subdevice_id = config.get(CONF_DEVICE_CID)
  460. device = TuyaLocalDevice(
  461. "Test",
  462. config[CONF_DEVICE_ID],
  463. config[CONF_HOST],
  464. config[CONF_LOCAL_KEY],
  465. config[CONF_PROTOCOL_VERSION],
  466. subdevice_id,
  467. hass,
  468. True,
  469. )
  470. return device
  471. async def async_test_connection(config: dict, hass: HomeAssistant):
  472. domain_data = hass.data.get(DOMAIN)
  473. existing = domain_data.get(get_device_id(config)) if domain_data else None
  474. if existing and existing.get("device"):
  475. _LOGGER.info("Pausing existing device to test new connection parameters")
  476. existing["device"].pause()
  477. await asyncio.sleep(5)
  478. try:
  479. device = await hass.async_add_executor_job(
  480. create_test_device,
  481. hass,
  482. config,
  483. )
  484. await device.async_refresh()
  485. retval = device if device.has_returned_state else None
  486. except Exception as e:
  487. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  488. retval = None
  489. if existing and existing.get("device"):
  490. _LOGGER.info("Restarting device after test")
  491. existing["device"].resume()
  492. return retval
  493. def scan_for_device(id):
  494. return tinytuya.find_device(dev_id=id)