config_flow.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. import asyncio
  2. import json
  3. import logging
  4. from collections import OrderedDict
  5. from typing import Any
  6. import tinytuya
  7. import voluptuous as vol
  8. from homeassistant import config_entries
  9. from homeassistant.const import CONF_HOST, CONF_NAME
  10. from homeassistant.core import HomeAssistant, callback
  11. from homeassistant.data_entry_flow import FlowResult
  12. from homeassistant.helpers.selector import (
  13. QrCodeSelector,
  14. QrCodeSelectorConfig,
  15. QrErrorCorrectionLevel,
  16. SelectOptionDict,
  17. SelectSelector,
  18. SelectSelectorConfig,
  19. SelectSelectorMode,
  20. )
  21. from . import DOMAIN
  22. from .cloud import Cloud
  23. from .const import (
  24. API_PROTOCOL_VERSIONS,
  25. CONF_DEVICE_CID,
  26. CONF_DEVICE_ID,
  27. CONF_LOCAL_KEY,
  28. CONF_POLL_ONLY,
  29. CONF_PROTOCOL_VERSION,
  30. CONF_TYPE,
  31. CONF_USER_CODE,
  32. DATA_STORE,
  33. )
  34. from .device import TuyaLocalDevice
  35. from .helpers.config import get_device_id
  36. from .helpers.device_config import get_config
  37. from .helpers.log import log_json
  38. _LOGGER = logging.getLogger(__name__)
  39. class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
  40. VERSION = 13
  41. MINOR_VERSION = 6
  42. CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
  43. device = None
  44. data = {}
  45. __qr_code: str | None = None
  46. __cloud_devices: dict[str, Any] = {}
  47. __cloud_device: dict[str, Any] | None = None
  48. def __init__(self) -> None:
  49. """Initialize the config flow."""
  50. self.cloud = None
  51. def init_cloud(self):
  52. if self.cloud is None:
  53. self.cloud = Cloud(self.hass)
  54. async def async_step_user(self, user_input=None):
  55. errors = {}
  56. if self.hass.data.get(DOMAIN) is None:
  57. self.hass.data[DOMAIN] = {}
  58. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  59. self.hass.data[DOMAIN][DATA_STORE] = {}
  60. if user_input is not None:
  61. if user_input["setup_mode"] == "cloud":
  62. self.init_cloud()
  63. try:
  64. if self.cloud.is_authenticated:
  65. self.__cloud_devices = await self.cloud.async_get_devices()
  66. return await self.async_step_choose_device()
  67. except Exception as e:
  68. # Re-authentication is needed.
  69. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  70. _LOGGER.warning("Re-authentication is required.")
  71. return await self.async_step_cloud()
  72. if user_input["setup_mode"] == "manual":
  73. return await self.async_step_local()
  74. # Build form
  75. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  76. fields[vol.Required("setup_mode")] = SelectSelector(
  77. SelectSelectorConfig(
  78. options=["cloud", "manual"],
  79. mode=SelectSelectorMode.LIST,
  80. translation_key="setup_mode",
  81. )
  82. )
  83. return self.async_show_form(
  84. step_id="user",
  85. data_schema=vol.Schema(fields),
  86. errors=errors or {},
  87. last_step=False,
  88. )
  89. async def async_step_cloud(
  90. self, user_input: dict[str, Any] | None = None
  91. ) -> FlowResult:
  92. """Step user."""
  93. errors = {}
  94. placeholders = {}
  95. self.init_cloud()
  96. if user_input is not None:
  97. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  98. if response:
  99. self.__qr_code = response
  100. return await self.async_step_scan()
  101. errors["base"] = "login_error"
  102. placeholders = self.cloud.last_error
  103. else:
  104. user_input = {}
  105. return self.async_show_form(
  106. step_id="cloud",
  107. data_schema=vol.Schema(
  108. {
  109. vol.Required(
  110. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  111. ): str,
  112. }
  113. ),
  114. errors=errors,
  115. description_placeholders=placeholders,
  116. )
  117. async def async_step_scan(
  118. self, user_input: dict[str, Any] | None = None
  119. ) -> FlowResult:
  120. """Step scan."""
  121. if user_input is None:
  122. return self.async_show_form(
  123. step_id="scan",
  124. data_schema=vol.Schema(
  125. {
  126. vol.Optional("QR"): QrCodeSelector(
  127. config=QrCodeSelectorConfig(
  128. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  129. scale=5,
  130. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  131. )
  132. )
  133. }
  134. ),
  135. )
  136. self.init_cloud()
  137. if not await self.cloud.async_login():
  138. # Try to get a new QR code on failure
  139. response = await self.cloud.async_get_qr_code()
  140. errors = {"base": "login_error"}
  141. placeholders = self.cloud.last_error
  142. if response:
  143. self.__qr_code = response
  144. return self.async_show_form(
  145. step_id="scan",
  146. errors=errors,
  147. data_schema=vol.Schema(
  148. {
  149. vol.Optional("QR"): QrCodeSelector(
  150. config=QrCodeSelectorConfig(
  151. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  152. scale=5,
  153. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  154. )
  155. )
  156. }
  157. ),
  158. description_placeholders=placeholders,
  159. )
  160. self.__cloud_devices = await self.cloud.async_get_devices()
  161. return await self.async_step_choose_device()
  162. async def async_step_choose_device(self, user_input=None):
  163. errors = {}
  164. if user_input is not None:
  165. device_choice = self.__cloud_devices[user_input["device_id"]]
  166. if device_choice["ip"] != "":
  167. # This is a directly addable device.
  168. if user_input["hub_id"] == "None":
  169. device_choice["ip"] = ""
  170. self.__cloud_device = device_choice
  171. return await self.async_step_search()
  172. else:
  173. # Show error if user selected a hub.
  174. errors["base"] = "does_not_need_hub"
  175. # Fall through to reshow the form.
  176. else:
  177. # This is an indirectly addressable device. Need to know which hub it is connected to.
  178. if user_input["hub_id"] != "None":
  179. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  180. # Populate uuid and local_key from the child device to pass on complete information to the local step.
  181. hub_choice["ip"] = ""
  182. hub_choice[CONF_DEVICE_CID] = device_choice["uuid"]
  183. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  184. self.__cloud_device = hub_choice
  185. return await self.async_step_search()
  186. else:
  187. # Show error if user did not select a hub.
  188. errors["base"] = "needs_hub"
  189. # Fall through to reshow the form.
  190. device_list = []
  191. for key in self.__cloud_devices.keys():
  192. device_entry = self.__cloud_devices[key]
  193. if device_entry.get("exists"):
  194. continue
  195. if device_entry[CONF_LOCAL_KEY] != "":
  196. if device_entry["online"]:
  197. device_list.append(
  198. SelectOptionDict(
  199. value=key,
  200. label=f"{device_entry['name']} ({device_entry['product_name']})",
  201. )
  202. )
  203. else:
  204. device_list.append(
  205. SelectOptionDict(
  206. value=key,
  207. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  208. )
  209. )
  210. _LOGGER.debug(f"Device count: {len(device_list)}")
  211. if len(device_list) == 0:
  212. return self.async_abort(reason="no_devices")
  213. device_selector = SelectSelector(
  214. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  215. )
  216. hub_list = []
  217. hub_list.append(SelectOptionDict(value="None", label="None"))
  218. for key in self.__cloud_devices.keys():
  219. hub_entry = self.__cloud_devices[key]
  220. if hub_entry["is_hub"]:
  221. hub_list.append(
  222. SelectOptionDict(
  223. value=key,
  224. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  225. )
  226. )
  227. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  228. hub_selector = SelectSelector(
  229. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  230. )
  231. # Build form
  232. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  233. fields[vol.Required("device_id")] = device_selector
  234. fields[vol.Required("hub_id")] = hub_selector
  235. return self.async_show_form(
  236. step_id="choose_device",
  237. data_schema=vol.Schema(fields),
  238. errors=errors or {},
  239. last_step=False,
  240. )
  241. async def async_step_search(self, user_input=None):
  242. if user_input is not None:
  243. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  244. # This scan will take 18s with the default settings. If we cannot find the device we
  245. # will just leave the IP address blank and hope the user can discover the IP by other
  246. # means such as router device IP assignments.
  247. _LOGGER.debug(
  248. f"Scanning network to get IP address for {self.__cloud_device['id']}."
  249. )
  250. self.__cloud_device["ip"] = ""
  251. try:
  252. local_device = await self.hass.async_add_executor_job(
  253. scan_for_device, self.__cloud_device["id"]
  254. )
  255. except OSError:
  256. local_device = {"ip": None, "version": ""}
  257. if local_device["ip"] is not None:
  258. _LOGGER.debug(f"Found: {local_device}")
  259. self.__cloud_device["ip"] = local_device["ip"]
  260. self.__cloud_device["version"] = local_device["version"]
  261. else:
  262. _LOGGER.warning(f"Could not find device: {self.__cloud_device['id']}")
  263. return await self.async_step_local()
  264. return self.async_show_form(
  265. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  266. )
  267. async def async_step_local(self, user_input=None):
  268. errors = {}
  269. devid_opts = {}
  270. host_opts = {"default": ""}
  271. key_opts = {}
  272. proto_opts = {"default": 3.3}
  273. polling_opts = {"default": False}
  274. devcid_opts = {}
  275. if self.__cloud_device is not None:
  276. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  277. devid_opts = {"default": self.__cloud_device["id"]}
  278. host_opts = {"default": self.__cloud_device["ip"]}
  279. key_opts = {"default": self.__cloud_device[CONF_LOCAL_KEY]}
  280. if self.__cloud_device["version"] is not None:
  281. proto_opts = {"default": float(self.__cloud_device["version"])}
  282. if self.__cloud_device[CONF_DEVICE_CID] is not None:
  283. devcid_opts = {"default": self.__cloud_device[CONF_DEVICE_CID]}
  284. if user_input is not None:
  285. self.device = await async_test_connection(user_input, self.hass)
  286. if self.device:
  287. self.data = user_input
  288. return await self.async_step_select_type()
  289. else:
  290. errors["base"] = "connection"
  291. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  292. host_opts["default"] = user_input[CONF_HOST]
  293. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  294. if CONF_DEVICE_CID in user_input:
  295. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  296. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  297. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  298. return self.async_show_form(
  299. step_id="local",
  300. data_schema=vol.Schema(
  301. {
  302. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  303. vol.Required(CONF_HOST, **host_opts): str,
  304. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  305. vol.Required(
  306. CONF_PROTOCOL_VERSION,
  307. **proto_opts,
  308. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  309. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  310. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  311. }
  312. ),
  313. errors=errors,
  314. )
  315. async def async_step_select_type(self, user_input=None):
  316. if user_input is not None:
  317. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  318. return await self.async_step_choose_entities()
  319. types = []
  320. best_match = 0
  321. best_matching_type = None
  322. async for type in self.device.async_possible_types():
  323. types.append(type.config_type)
  324. q = type.match_quality(self.device._get_cached_state())
  325. if q > best_match:
  326. best_match = q
  327. best_matching_type = type.config_type
  328. best_match = int(best_match)
  329. dps = self.device._get_cached_state()
  330. if self.__cloud_device:
  331. _LOGGER.warning(
  332. "Adding %s device with product id %s",
  333. self.__cloud_device["product_name"],
  334. self.__cloud_device["product_id"],
  335. )
  336. # try:
  337. # self.init_cloud()
  338. # model = await self.cloud.async_get_datamodel(
  339. # self.__cloud_device["id"],
  340. # )
  341. # if model:
  342. # _LOGGER.warning(
  343. # "Device specficication:\n%s",
  344. # json.dumps(model, indent=4),
  345. # )
  346. # except Exception as e:
  347. # _LOGGER.warning("Unable to fetch data model from cloud: %s", e)
  348. _LOGGER.warning(
  349. "Device matches %s with quality of %d%%. DPS: %s",
  350. best_matching_type,
  351. best_match,
  352. log_json(dps),
  353. )
  354. _LOGGER.warning(
  355. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  356. )
  357. if types:
  358. return self.async_show_form(
  359. step_id="select_type",
  360. data_schema=vol.Schema(
  361. {
  362. vol.Required(
  363. CONF_TYPE,
  364. default=best_matching_type,
  365. ): vol.In(types),
  366. }
  367. ),
  368. )
  369. else:
  370. return self.async_abort(reason="not_supported")
  371. async def async_step_choose_entities(self, user_input=None):
  372. if user_input is not None:
  373. title = user_input[CONF_NAME]
  374. del user_input[CONF_NAME]
  375. return self.async_create_entry(
  376. title=title, data={**self.data, **user_input}
  377. )
  378. config = await self.hass.async_add_executor_job(
  379. get_config,
  380. self.data[CONF_TYPE],
  381. )
  382. schema = {vol.Required(CONF_NAME, default=config.name): str}
  383. return self.async_show_form(
  384. step_id="choose_entities",
  385. data_schema=vol.Schema(schema),
  386. )
  387. @staticmethod
  388. @callback
  389. def async_get_options_flow(config_entry):
  390. return OptionsFlowHandler(config_entry)
  391. class OptionsFlowHandler(config_entries.OptionsFlow):
  392. def __init__(self, config_entry):
  393. """Initialize options flow."""
  394. self.config_entry = config_entry
  395. async def async_step_init(self, user_input=None):
  396. return await self.async_step_user(user_input)
  397. async def async_step_user(self, user_input=None):
  398. """Manage the options."""
  399. errors = {}
  400. config = {**self.config_entry.data, **self.config_entry.options}
  401. if user_input is not None:
  402. config = {**config, **user_input}
  403. device = await async_test_connection(config, self.hass)
  404. if device:
  405. return self.async_create_entry(title="", data=user_input)
  406. else:
  407. errors["base"] = "connection"
  408. schema = {
  409. vol.Required(
  410. CONF_LOCAL_KEY,
  411. default=config.get(CONF_LOCAL_KEY, ""),
  412. ): str,
  413. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  414. vol.Required(
  415. CONF_PROTOCOL_VERSION,
  416. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  417. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  418. vol.Required(
  419. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  420. ): bool,
  421. vol.Optional(
  422. CONF_DEVICE_CID,
  423. default=config.get(CONF_DEVICE_CID, ""),
  424. ): str,
  425. }
  426. cfg = await self.hass.async_add_executor_job(
  427. get_config,
  428. config[CONF_TYPE],
  429. )
  430. if cfg is None:
  431. return self.async_abort(reason="not_supported")
  432. return self.async_show_form(
  433. step_id="user",
  434. data_schema=vol.Schema(schema),
  435. errors=errors,
  436. )
  437. def create_test_device(hass: HomeAssistant, config: dict):
  438. """Set up a tuya device based on passed in config."""
  439. subdevice_id = config.get(CONF_DEVICE_CID)
  440. device = TuyaLocalDevice(
  441. "Test",
  442. config[CONF_DEVICE_ID],
  443. config[CONF_HOST],
  444. config[CONF_LOCAL_KEY],
  445. config[CONF_PROTOCOL_VERSION],
  446. subdevice_id,
  447. hass,
  448. True,
  449. )
  450. return device
  451. async def async_test_connection(config: dict, hass: HomeAssistant):
  452. domain_data = hass.data.get(DOMAIN)
  453. existing = domain_data.get(get_device_id(config)) if domain_data else None
  454. if existing and existing.get("device"):
  455. _LOGGER.info("Pausing existing device to test new connection parameters")
  456. existing["device"].pause()
  457. await asyncio.sleep(5)
  458. try:
  459. device = await hass.async_add_executor_job(
  460. create_test_device,
  461. hass,
  462. config,
  463. )
  464. await device.async_refresh()
  465. retval = device if device.has_returned_state else None
  466. except Exception as e:
  467. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  468. retval = None
  469. if existing and existing.get("device"):
  470. _LOGGER.info("Restarting device after test")
  471. existing["device"].resume()
  472. return retval
  473. def scan_for_device(id):
  474. return tinytuya.find_device(dev_id=id)