config_flow.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. import asyncio
  2. import json
  3. import logging
  4. from collections import OrderedDict
  5. from typing import Any
  6. import tinytuya
  7. import voluptuous as vol
  8. from homeassistant import config_entries
  9. from homeassistant.const import CONF_HOST, CONF_NAME
  10. from homeassistant.core import HomeAssistant, callback
  11. from homeassistant.data_entry_flow import FlowResult
  12. from homeassistant.helpers.selector import (
  13. QrCodeSelector,
  14. QrCodeSelectorConfig,
  15. QrErrorCorrectionLevel,
  16. SelectOptionDict,
  17. SelectSelector,
  18. SelectSelectorConfig,
  19. SelectSelectorMode,
  20. )
  21. from . import DOMAIN
  22. from .cloud import Cloud
  23. from .const import (
  24. API_PROTOCOL_VERSIONS,
  25. CONF_DEVICE_CID,
  26. CONF_DEVICE_ID,
  27. CONF_LOCAL_KEY,
  28. CONF_POLL_ONLY,
  29. CONF_PROTOCOL_VERSION,
  30. CONF_TYPE,
  31. CONF_USER_CODE,
  32. DATA_STORE,
  33. )
  34. from .device import TuyaLocalDevice
  35. from .helpers.config import get_device_id
  36. from .helpers.device_config import get_config
  37. from .helpers.log import log_json
  38. _LOGGER = logging.getLogger(__name__)
  39. class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
  40. VERSION = 13
  41. MINOR_VERSION = 6
  42. CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
  43. device = None
  44. data = {}
  45. __qr_code: str | None = None
  46. __cloud_devices: dict[str, Any] = {}
  47. __cloud_device: dict[str, Any] | None = None
  48. def __init__(self) -> None:
  49. """Initialize the config flow."""
  50. self.cloud = Cloud(self.hass)
  51. async def async_step_user(self, user_input=None):
  52. errors = {}
  53. if self.hass.data.get(DOMAIN) is None:
  54. self.hass.data[DOMAIN] = {}
  55. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  56. self.hass.data[DOMAIN][DATA_STORE] = {}
  57. if user_input is not None:
  58. if user_input["setup_mode"] == "cloud":
  59. try:
  60. if self.cloud.is_authenticated:
  61. self.__cloud_devices = await self.cloud.async_get_devices()
  62. return await self.async_step_choose_device()
  63. except Exception as e:
  64. # Re-authentication is needed.
  65. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  66. _LOGGER.warning("Re-authentication is required.")
  67. return await self.async_step_cloud()
  68. if user_input["setup_mode"] == "manual":
  69. return await self.async_step_local()
  70. # Build form
  71. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  72. fields[vol.Required("setup_mode")] = SelectSelector(
  73. SelectSelectorConfig(
  74. options=["cloud", "manual"],
  75. mode=SelectSelectorMode.LIST,
  76. translation_key="setup_mode",
  77. )
  78. )
  79. return self.async_show_form(
  80. step_id="user",
  81. data_schema=vol.Schema(fields),
  82. errors=errors or {},
  83. last_step=False,
  84. )
  85. async def async_step_cloud(
  86. self, user_input: dict[str, Any] | None = None
  87. ) -> FlowResult:
  88. """Step user."""
  89. errors = {}
  90. placeholders = {}
  91. if user_input is not None:
  92. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  93. if response:
  94. self.__qr_code = response
  95. return await self.async_step_scan()
  96. errors["base"] = "login_error"
  97. placeholders = self.cloud.last_error
  98. else:
  99. user_input = {}
  100. return self.async_show_form(
  101. step_id="cloud",
  102. data_schema=vol.Schema(
  103. {
  104. vol.Required(
  105. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  106. ): str,
  107. }
  108. ),
  109. errors=errors,
  110. description_placeholders=placeholders,
  111. )
  112. async def async_step_scan(
  113. self, user_input: dict[str, Any] | None = None
  114. ) -> FlowResult:
  115. """Step scan."""
  116. if user_input is None:
  117. return self.async_show_form(
  118. step_id="scan",
  119. data_schema=vol.Schema(
  120. {
  121. vol.Optional("QR"): QrCodeSelector(
  122. config=QrCodeSelectorConfig(
  123. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  124. scale=5,
  125. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  126. )
  127. )
  128. }
  129. ),
  130. )
  131. if not await self.cloud.async_login():
  132. # Try to get a new QR code on failure
  133. response = await self.cloud.async_get_qr_code()
  134. errors = {"base": "login_error"}
  135. placeholders = self.cloud.last_error
  136. if response:
  137. self.__qr_code = response
  138. return self.async_show_form(
  139. step_id="scan",
  140. errors=errors,
  141. data_schema=vol.Schema(
  142. {
  143. vol.Optional("QR"): QrCodeSelector(
  144. config=QrCodeSelectorConfig(
  145. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  146. scale=5,
  147. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  148. )
  149. )
  150. }
  151. ),
  152. description_placeholders=placeholders,
  153. )
  154. self.__cloud_devices = await self.cloud.async_get_devices()
  155. return await self.async_step_choose_device()
  156. async def async_step_choose_device(self, user_input=None):
  157. errors = {}
  158. if user_input is not None:
  159. device_choice = self.__cloud_devices[user_input["device_id"]]
  160. if device_choice["ip"] != "":
  161. # This is a directly addable device.
  162. if user_input["hub_id"] == "None":
  163. device_choice["ip"] = ""
  164. self.__cloud_device = device_choice
  165. return await self.async_step_search()
  166. else:
  167. # Show error if user selected a hub.
  168. errors["base"] = "does_not_need_hub"
  169. # Fall through to reshow the form.
  170. else:
  171. # This is an indirectly addressable device. Need to know which hub it is connected to.
  172. if user_input["hub_id"] != "None":
  173. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  174. # Populate uuid and local_key from the child device to pass on complete information to the local step.
  175. hub_choice["ip"] = ""
  176. hub_choice[CONF_DEVICE_CID] = device_choice["uuid"]
  177. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  178. self.__cloud_device = hub_choice
  179. return await self.async_step_search()
  180. else:
  181. # Show error if user did not select a hub.
  182. errors["base"] = "needs_hub"
  183. # Fall through to reshow the form.
  184. device_list = []
  185. for key in self.__cloud_devices.keys():
  186. device_entry = self.__cloud_devices[key]
  187. if device_entry.get("exists"):
  188. continue
  189. if device_entry[CONF_LOCAL_KEY] != "":
  190. if device_entry["online"]:
  191. device_list.append(
  192. SelectOptionDict(
  193. value=key,
  194. label=f"{device_entry['name']} ({device_entry['product_name']})",
  195. )
  196. )
  197. else:
  198. device_list.append(
  199. SelectOptionDict(
  200. value=key,
  201. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  202. )
  203. )
  204. _LOGGER.debug(f"Device count: {len(device_list)}")
  205. if len(device_list) == 0:
  206. return self.async_abort(reason="no_devices")
  207. device_selector = SelectSelector(
  208. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  209. )
  210. hub_list = []
  211. hub_list.append(SelectOptionDict(value="None", label="None"))
  212. for key in self.__cloud_devices.keys():
  213. hub_entry = self.__cloud_devices[key]
  214. if hub_entry["is_hub"]:
  215. hub_list.append(
  216. SelectOptionDict(
  217. value=key,
  218. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  219. )
  220. )
  221. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  222. hub_selector = SelectSelector(
  223. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  224. )
  225. # Build form
  226. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  227. fields[vol.Required("device_id")] = device_selector
  228. fields[vol.Required("hub_id")] = hub_selector
  229. return self.async_show_form(
  230. step_id="choose_device",
  231. data_schema=vol.Schema(fields),
  232. errors=errors or {},
  233. last_step=False,
  234. )
  235. async def async_step_search(self, user_input=None):
  236. if user_input is not None:
  237. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  238. # This scan will take 18s with the default settings. If we cannot find the device we
  239. # will just leave the IP address blank and hope the user can discover the IP by other
  240. # means such as router device IP assignments.
  241. _LOGGER.debug(
  242. f"Scanning network to get IP address for {self.__cloud_device['id']}."
  243. )
  244. self.__cloud_device["ip"] = ""
  245. try:
  246. local_device = await self.hass.async_add_executor_job(
  247. scan_for_device, self.__cloud_device["id"]
  248. )
  249. except OSError:
  250. local_device = {"ip": None, "version": ""}
  251. if local_device["ip"] is not None:
  252. _LOGGER.debug(f"Found: {local_device}")
  253. self.__cloud_device["ip"] = local_device["ip"]
  254. self.__cloud_device["version"] = local_device["version"]
  255. else:
  256. _LOGGER.warning(f"Could not find device: {self.__cloud_device['id']}")
  257. return await self.async_step_local()
  258. return self.async_show_form(
  259. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  260. )
  261. async def async_step_local(self, user_input=None):
  262. errors = {}
  263. devid_opts = {}
  264. host_opts = {"default": ""}
  265. key_opts = {}
  266. proto_opts = {"default": 3.3}
  267. polling_opts = {"default": False}
  268. devcid_opts = {}
  269. if self.__cloud_device is not None:
  270. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  271. devid_opts = {"default": self.__cloud_device["id"]}
  272. host_opts = {"default": self.__cloud_device["ip"]}
  273. key_opts = {"default": self.__cloud_device[CONF_LOCAL_KEY]}
  274. if self.__cloud_device["version"] is not None:
  275. proto_opts = {"default": float(self.__cloud_device["version"])}
  276. if self.__cloud_device[CONF_DEVICE_CID] is not None:
  277. devcid_opts = {"default": self.__cloud_device[CONF_DEVICE_CID]}
  278. if user_input is not None:
  279. self.device = await async_test_connection(user_input, self.hass)
  280. if self.device:
  281. self.data = user_input
  282. return await self.async_step_select_type()
  283. else:
  284. errors["base"] = "connection"
  285. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  286. host_opts["default"] = user_input[CONF_HOST]
  287. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  288. if CONF_DEVICE_CID in user_input:
  289. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  290. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  291. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  292. return self.async_show_form(
  293. step_id="local",
  294. data_schema=vol.Schema(
  295. {
  296. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  297. vol.Required(CONF_HOST, **host_opts): str,
  298. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  299. vol.Required(
  300. CONF_PROTOCOL_VERSION,
  301. **proto_opts,
  302. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  303. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  304. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  305. }
  306. ),
  307. errors=errors,
  308. )
  309. async def async_step_select_type(self, user_input=None):
  310. if user_input is not None:
  311. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  312. return await self.async_step_choose_entities()
  313. types = []
  314. best_match = 0
  315. best_matching_type = None
  316. async for type in self.device.async_possible_types():
  317. types.append(type.config_type)
  318. q = type.match_quality(self.device._get_cached_state())
  319. if q > best_match:
  320. best_match = q
  321. best_matching_type = type.config_type
  322. best_match = int(best_match)
  323. dps = self.device._get_cached_state()
  324. if self.__cloud_device:
  325. _LOGGER.warning(
  326. "Adding %s device with product id %s",
  327. self.__cloud_device["product_name"],
  328. self.__cloud_device["product_id"],
  329. )
  330. response = self.cloud.async_get_datamodel(self.__cloud_device["device_id"])
  331. if response and response["result"] and response["result"]["model"]:
  332. model = json.loads(response["result"]["model"])
  333. _LOGGER.warning(
  334. "QueryThingsDataModel result:\n%s",
  335. json.dumps(model, indent=4),
  336. )
  337. _LOGGER.warning(
  338. "Device matches %s with quality of %d%%. DPS: %s",
  339. best_matching_type,
  340. best_match,
  341. log_json(dps),
  342. )
  343. _LOGGER.warning(
  344. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  345. )
  346. if types:
  347. return self.async_show_form(
  348. step_id="select_type",
  349. data_schema=vol.Schema(
  350. {
  351. vol.Required(
  352. CONF_TYPE,
  353. default=best_matching_type,
  354. ): vol.In(types),
  355. }
  356. ),
  357. )
  358. else:
  359. return self.async_abort(reason="not_supported")
  360. async def async_step_choose_entities(self, user_input=None):
  361. if user_input is not None:
  362. title = user_input[CONF_NAME]
  363. del user_input[CONF_NAME]
  364. return self.async_create_entry(
  365. title=title, data={**self.data, **user_input}
  366. )
  367. config = await self.hass.async_add_executor_job(
  368. get_config,
  369. self.data[CONF_TYPE],
  370. )
  371. schema = {vol.Required(CONF_NAME, default=config.name): str}
  372. return self.async_show_form(
  373. step_id="choose_entities",
  374. data_schema=vol.Schema(schema),
  375. )
  376. @staticmethod
  377. @callback
  378. def async_get_options_flow(config_entry):
  379. return OptionsFlowHandler(config_entry)
  380. class OptionsFlowHandler(config_entries.OptionsFlow):
  381. def __init__(self, config_entry):
  382. """Initialize options flow."""
  383. self.config_entry = config_entry
  384. async def async_step_init(self, user_input=None):
  385. return await self.async_step_user(user_input)
  386. async def async_step_user(self, user_input=None):
  387. """Manage the options."""
  388. errors = {}
  389. config = {**self.config_entry.data, **self.config_entry.options}
  390. if user_input is not None:
  391. config = {**config, **user_input}
  392. device = await async_test_connection(config, self.hass)
  393. if device:
  394. return self.async_create_entry(title="", data=user_input)
  395. else:
  396. errors["base"] = "connection"
  397. schema = {
  398. vol.Required(
  399. CONF_LOCAL_KEY,
  400. default=config.get(CONF_LOCAL_KEY, ""),
  401. ): str,
  402. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  403. vol.Required(
  404. CONF_PROTOCOL_VERSION,
  405. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  406. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  407. vol.Required(
  408. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  409. ): bool,
  410. vol.Optional(
  411. CONF_DEVICE_CID,
  412. default=config.get(CONF_DEVICE_CID, ""),
  413. ): str,
  414. }
  415. cfg = await self.hass.async_add_executor_job(
  416. get_config,
  417. config[CONF_TYPE],
  418. )
  419. if cfg is None:
  420. return self.async_abort(reason="not_supported")
  421. return self.async_show_form(
  422. step_id="user",
  423. data_schema=vol.Schema(schema),
  424. errors=errors,
  425. )
  426. def create_test_device(hass: HomeAssistant, config: dict):
  427. """Set up a tuya device based on passed in config."""
  428. subdevice_id = config.get(CONF_DEVICE_CID)
  429. device = TuyaLocalDevice(
  430. "Test",
  431. config[CONF_DEVICE_ID],
  432. config[CONF_HOST],
  433. config[CONF_LOCAL_KEY],
  434. config[CONF_PROTOCOL_VERSION],
  435. subdevice_id,
  436. hass,
  437. True,
  438. )
  439. return device
  440. async def async_test_connection(config: dict, hass: HomeAssistant):
  441. domain_data = hass.data.get(DOMAIN)
  442. existing = domain_data.get(get_device_id(config)) if domain_data else None
  443. if existing and existing.get("device"):
  444. _LOGGER.info("Pausing existing device to test new connection parameters")
  445. existing["device"].pause()
  446. await asyncio.sleep(5)
  447. try:
  448. device = await hass.async_add_executor_job(
  449. create_test_device,
  450. hass,
  451. config,
  452. )
  453. await device.async_refresh()
  454. retval = device if device.has_returned_state else None
  455. except Exception as e:
  456. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  457. retval = None
  458. if existing and existing.get("device"):
  459. _LOGGER.info("Restarting device after test")
  460. existing["device"].resume()
  461. return retval
  462. def scan_for_device(id):
  463. return tinytuya.find_device(dev_id=id)