# config_flow.py (19 KB)
import asyncio
import json
import logging
from collections import OrderedDict
from typing import Any

import tinytuya
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import CONF_HOST, CONF_NAME
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
    QrCodeSelector,
    QrCodeSelectorConfig,
    QrErrorCorrectionLevel,
    SelectOptionDict,
    SelectSelector,
    SelectSelectorConfig,
    SelectSelectorMode,
)

from . import DOMAIN
from .cloud import Cloud
from .const import (
    API_PROTOCOL_VERSIONS,
    CONF_DEVICE_CID,
    CONF_DEVICE_ID,
    CONF_LOCAL_KEY,
    CONF_POLL_ONLY,
    CONF_PROTOCOL_VERSION,
    CONF_TYPE,
    CONF_USER_CODE,
    DATA_STORE,
)
from .device import TuyaLocalDevice
from .helpers.config import get_device_id
from .helpers.device_config import get_config
from .helpers.log import log_json
  37. _LOGGER = logging.getLogger(__name__)
  38. class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
  39. VERSION = 13
  40. MINOR_VERSION = 6
  41. CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
  42. device = None
  43. data = {}
  44. __qr_code: str | None = None
  45. __cloud_devices: dict[str, Any] = {}
  46. __cloud_device: dict[str, Any] | None = None
  47. def __init__(self) -> None:
  48. """Initialize the config flow."""
  49. self.cloud = Cloud(self.hass)
  50. async def async_step_user(self, user_input=None):
  51. errors = {}
  52. if self.hass.data.get(DOMAIN) is None:
  53. self.hass.data[DOMAIN] = {}
  54. if self.hass.data[DOMAIN].get(DATA_STORE) is None:
  55. self.hass.data[DOMAIN][DATA_STORE] = {}
  56. if user_input is not None:
  57. if user_input["setup_mode"] == "cloud":
  58. try:
  59. if self.cloud.is_authenticated:
  60. self.__cloud_devices = await self.cloud.async_get_devices()
  61. return await self.async_step_choose_device()
  62. except Exception as e:
  63. # Re-authentication is needed.
  64. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  65. _LOGGER.warning("Re-authentication is required.")
  66. return await self.async_step_cloud()
  67. if user_input["setup_mode"] == "manual":
  68. return await self.async_step_local()
  69. # Build form
  70. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  71. fields[vol.Required("setup_mode")] = SelectSelector(
  72. SelectSelectorConfig(
  73. options=["cloud", "manual"],
  74. mode=SelectSelectorMode.LIST,
  75. translation_key="setup_mode",
  76. )
  77. )
  78. return self.async_show_form(
  79. step_id="user",
  80. data_schema=vol.Schema(fields),
  81. errors=errors or {},
  82. last_step=False,
  83. )
  84. async def async_step_cloud(
  85. self, user_input: dict[str, Any] | None = None
  86. ) -> FlowResult:
  87. """Step user."""
  88. errors = {}
  89. placeholders = {}
  90. if user_input is not None:
  91. response = await self.cloud.async_get_qr_code(user_input[CONF_USER_CODE])
  92. if response:
  93. self.__qr_code = response
  94. return await self.async_step_scan()
  95. errors["base"] = "login_error"
  96. placeholders = self.cloud.last_error
  97. else:
  98. user_input = {}
  99. return self.async_show_form(
  100. step_id="cloud",
  101. data_schema=vol.Schema(
  102. {
  103. vol.Required(
  104. CONF_USER_CODE, default=user_input.get(CONF_USER_CODE, "")
  105. ): str,
  106. }
  107. ),
  108. errors=errors,
  109. description_placeholders=placeholders,
  110. )
  111. async def async_step_scan(
  112. self, user_input: dict[str, Any] | None = None
  113. ) -> FlowResult:
  114. """Step scan."""
  115. if user_input is None:
  116. return self.async_show_form(
  117. step_id="scan",
  118. data_schema=vol.Schema(
  119. {
  120. vol.Optional("QR"): QrCodeSelector(
  121. config=QrCodeSelectorConfig(
  122. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  123. scale=5,
  124. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  125. )
  126. )
  127. }
  128. ),
  129. )
  130. if not await self.cloud.async_login():
  131. # Try to get a new QR code on failure
  132. response = await self.cloud.async_get_qr_code()
  133. errors = {"base": "login_error"}
  134. placeholders = self.cloud.last_error
  135. if response:
  136. self.__qr_code = response
  137. return self.async_show_form(
  138. step_id="scan",
  139. errors=errors,
  140. data_schema=vol.Schema(
  141. {
  142. vol.Optional("QR"): QrCodeSelector(
  143. config=QrCodeSelectorConfig(
  144. data=f"tuyaSmart--qrLogin?token={self.__qr_code}",
  145. scale=5,
  146. error_correction_level=QrErrorCorrectionLevel.QUARTILE,
  147. )
  148. )
  149. }
  150. ),
  151. description_placeholders=placeholders,
  152. )
  153. self.__cloud_devices = await self.cloud.async_get_devices()
  154. return await self.async_step_choose_device()
  155. async def async_step_choose_device(self, user_input=None):
  156. errors = {}
  157. if user_input is not None:
  158. device_choice = self.__cloud_devices[user_input["device_id"]]
  159. if device_choice["ip"] != "":
  160. # This is a directly addable device.
  161. if user_input["hub_id"] == "None":
  162. device_choice["ip"] = ""
  163. self.__cloud_device = device_choice
  164. return await self.async_step_search()
  165. else:
  166. # Show error if user selected a hub.
  167. errors["base"] = "does_not_need_hub"
  168. # Fall through to reshow the form.
  169. else:
  170. # This is an indirectly addressable device. Need to know which hub it is connected to.
  171. if user_input["hub_id"] != "None":
  172. hub_choice = self.__cloud_devices[user_input["hub_id"]]
  173. # Populate uuid and local_key from the child device to pass on complete information to the local step.
  174. hub_choice["ip"] = ""
  175. hub_choice[CONF_DEVICE_CID] = device_choice["uuid"]
  176. hub_choice[CONF_LOCAL_KEY] = device_choice[CONF_LOCAL_KEY]
  177. self.__cloud_device = hub_choice
  178. return await self.async_step_search()
  179. else:
  180. # Show error if user did not select a hub.
  181. errors["base"] = "needs_hub"
  182. # Fall through to reshow the form.
  183. device_list = []
  184. for key in self.__cloud_devices.keys():
  185. device_entry = self.__cloud_devices[key]
  186. if device_entry.get("exists"):
  187. continue
  188. if device_entry[CONF_LOCAL_KEY] != "":
  189. if device_entry["online"]:
  190. device_list.append(
  191. SelectOptionDict(
  192. value=key,
  193. label=f"{device_entry['name']} ({device_entry['product_name']})",
  194. )
  195. )
  196. else:
  197. device_list.append(
  198. SelectOptionDict(
  199. value=key,
  200. label=f"{device_entry['name']} ({device_entry['product_name']}) OFFLINE",
  201. )
  202. )
  203. _LOGGER.debug(f"Device count: {len(device_list)}")
  204. if len(device_list) == 0:
  205. return self.async_abort(reason="no_devices")
  206. device_selector = SelectSelector(
  207. SelectSelectorConfig(options=device_list, mode=SelectSelectorMode.DROPDOWN)
  208. )
  209. hub_list = []
  210. hub_list.append(SelectOptionDict(value="None", label="None"))
  211. for key in self.__cloud_devices.keys():
  212. hub_entry = self.__cloud_devices[key]
  213. if hub_entry["is_hub"]:
  214. hub_list.append(
  215. SelectOptionDict(
  216. value=key,
  217. label=f"{hub_entry['name']} ({hub_entry['product_name']})",
  218. )
  219. )
  220. _LOGGER.debug(f"Hub count: {len(hub_list) - 1}")
  221. hub_selector = SelectSelector(
  222. SelectSelectorConfig(options=hub_list, mode=SelectSelectorMode.DROPDOWN)
  223. )
  224. # Build form
  225. fields: OrderedDict[vol.Marker, Any] = OrderedDict()
  226. fields[vol.Required("device_id")] = device_selector
  227. fields[vol.Required("hub_id")] = hub_selector
  228. return self.async_show_form(
  229. step_id="choose_device",
  230. data_schema=vol.Schema(fields),
  231. errors=errors or {},
  232. last_step=False,
  233. )
  234. async def async_step_search(self, user_input=None):
  235. if user_input is not None:
  236. # Current IP is the WAN IP which is of no use. Need to try and discover to the local IP.
  237. # This scan will take 18s with the default settings. If we cannot find the device we
  238. # will just leave the IP address blank and hope the user can discover the IP by other
  239. # means such as router device IP assignments.
  240. _LOGGER.debug(
  241. f"Scanning network to get IP address for {self.__cloud_device['id']}."
  242. )
  243. self.__cloud_device["ip"] = ""
  244. try:
  245. local_device = await self.hass.async_add_executor_job(
  246. scan_for_device, self.__cloud_device["id"]
  247. )
  248. except OSError:
  249. local_device = {"ip": None, "version": ""}
  250. if local_device["ip"] is not None:
  251. _LOGGER.debug(f"Found: {local_device}")
  252. self.__cloud_device["ip"] = local_device["ip"]
  253. self.__cloud_device["version"] = local_device["version"]
  254. else:
  255. _LOGGER.warning(f"Could not find device: {self.__cloud_device['id']}")
  256. return await self.async_step_local()
  257. return self.async_show_form(
  258. step_id="search", data_schema=vol.Schema({}), errors={}, last_step=False
  259. )
  260. async def async_step_local(self, user_input=None):
  261. errors = {}
  262. devid_opts = {}
  263. host_opts = {"default": ""}
  264. key_opts = {}
  265. proto_opts = {"default": 3.3}
  266. polling_opts = {"default": False}
  267. devcid_opts = {}
  268. if self.__cloud_device is not None:
  269. # We already have some or all of the device settings from the cloud flow. Set them into the defaults.
  270. devid_opts = {"default": self.__cloud_device["id"]}
  271. host_opts = {"default": self.__cloud_device["ip"]}
  272. key_opts = {"default": self.__cloud_device[CONF_LOCAL_KEY]}
  273. if self.__cloud_device["version"] is not None:
  274. proto_opts = {"default": float(self.__cloud_device["version"])}
  275. if self.__cloud_device[CONF_DEVICE_CID] is not None:
  276. devcid_opts = {"default": self.__cloud_device[CONF_DEVICE_CID]}
  277. if user_input is not None:
  278. self.device = await async_test_connection(user_input, self.hass)
  279. if self.device:
  280. self.data = user_input
  281. return await self.async_step_select_type()
  282. else:
  283. errors["base"] = "connection"
  284. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  285. host_opts["default"] = user_input[CONF_HOST]
  286. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  287. if CONF_DEVICE_CID in user_input:
  288. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  289. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  290. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  291. return self.async_show_form(
  292. step_id="local",
  293. data_schema=vol.Schema(
  294. {
  295. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  296. vol.Required(CONF_HOST, **host_opts): str,
  297. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  298. vol.Required(
  299. CONF_PROTOCOL_VERSION,
  300. **proto_opts,
  301. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  302. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  303. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  304. }
  305. ),
  306. errors=errors,
  307. )
  308. async def async_step_select_type(self, user_input=None):
  309. if user_input is not None:
  310. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  311. return await self.async_step_choose_entities()
  312. types = []
  313. best_match = 0
  314. best_matching_type = None
  315. async for type in self.device.async_possible_types():
  316. types.append(type.config_type)
  317. q = type.match_quality(self.device._get_cached_state())
  318. if q > best_match:
  319. best_match = q
  320. best_matching_type = type.config_type
  321. best_match = int(best_match)
  322. dps = self.device._get_cached_state()
  323. if self.__cloud_device:
  324. _LOGGER.warning(
  325. "Adding %s device with product id %s",
  326. self.__cloud_device["product_name"],
  327. self.__cloud_device["product_id"],
  328. )
  329. response = self.cloud.async_get_datamodel(self.__cloud_device["device_id"])
  330. if response and response["result"] and response["result"]["model"]:
  331. model = json.loads(response["result"]["model"])
  332. _LOGGER.warning(
  333. "QueryThingsDataModel result:\n%s",
  334. json.dumps(model, indent=4),
  335. )
  336. _LOGGER.warning(
  337. "Device matches %s with quality of %d%%. DPS: %s",
  338. best_matching_type,
  339. best_match,
  340. log_json(dps),
  341. )
  342. _LOGGER.warning(
  343. "Include the previous log messages with any new device request to https://github.com/make-all/tuya-local/issues/",
  344. )
  345. if types:
  346. return self.async_show_form(
  347. step_id="select_type",
  348. data_schema=vol.Schema(
  349. {
  350. vol.Required(
  351. CONF_TYPE,
  352. default=best_matching_type,
  353. ): vol.In(types),
  354. }
  355. ),
  356. )
  357. else:
  358. return self.async_abort(reason="not_supported")
  359. async def async_step_choose_entities(self, user_input=None):
  360. if user_input is not None:
  361. title = user_input[CONF_NAME]
  362. del user_input[CONF_NAME]
  363. return self.async_create_entry(
  364. title=title, data={**self.data, **user_input}
  365. )
  366. config = await self.hass.async_add_executor_job(
  367. get_config,
  368. self.data[CONF_TYPE],
  369. )
  370. schema = {vol.Required(CONF_NAME, default=config.name): str}
  371. return self.async_show_form(
  372. step_id="choose_entities",
  373. data_schema=vol.Schema(schema),
  374. )
  375. @staticmethod
  376. @callback
  377. def async_get_options_flow(config_entry):
  378. return OptionsFlowHandler(config_entry)
  379. class OptionsFlowHandler(config_entries.OptionsFlow):
  380. def __init__(self, config_entry):
  381. """Initialize options flow."""
  382. self.config_entry = config_entry
  383. async def async_step_init(self, user_input=None):
  384. return await self.async_step_user(user_input)
  385. async def async_step_user(self, user_input=None):
  386. """Manage the options."""
  387. errors = {}
  388. config = {**self.config_entry.data, **self.config_entry.options}
  389. if user_input is not None:
  390. config = {**config, **user_input}
  391. device = await async_test_connection(config, self.hass)
  392. if device:
  393. return self.async_create_entry(title="", data=user_input)
  394. else:
  395. errors["base"] = "connection"
  396. schema = {
  397. vol.Required(
  398. CONF_LOCAL_KEY,
  399. default=config.get(CONF_LOCAL_KEY, ""),
  400. ): str,
  401. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  402. vol.Required(
  403. CONF_PROTOCOL_VERSION,
  404. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  405. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  406. vol.Required(
  407. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  408. ): bool,
  409. vol.Optional(
  410. CONF_DEVICE_CID,
  411. default=config.get(CONF_DEVICE_CID, ""),
  412. ): str,
  413. }
  414. cfg = await self.hass.async_add_executor_job(
  415. get_config,
  416. config[CONF_TYPE],
  417. )
  418. if cfg is None:
  419. return self.async_abort(reason="not_supported")
  420. return self.async_show_form(
  421. step_id="user",
  422. data_schema=vol.Schema(schema),
  423. errors=errors,
  424. )
  425. def create_test_device(hass: HomeAssistant, config: dict):
  426. """Set up a tuya device based on passed in config."""
  427. subdevice_id = config.get(CONF_DEVICE_CID)
  428. device = TuyaLocalDevice(
  429. "Test",
  430. config[CONF_DEVICE_ID],
  431. config[CONF_HOST],
  432. config[CONF_LOCAL_KEY],
  433. config[CONF_PROTOCOL_VERSION],
  434. subdevice_id,
  435. hass,
  436. True,
  437. )
  438. return device
  439. async def async_test_connection(config: dict, hass: HomeAssistant):
  440. domain_data = hass.data.get(DOMAIN)
  441. existing = domain_data.get(get_device_id(config)) if domain_data else None
  442. if existing and existing.get("device"):
  443. _LOGGER.info("Pausing existing device to test new connection parameters")
  444. existing["device"].pause()
  445. await asyncio.sleep(5)
  446. try:
  447. device = await hass.async_add_executor_job(
  448. create_test_device,
  449. hass,
  450. config,
  451. )
  452. await device.async_refresh()
  453. retval = device if device.has_returned_state else None
  454. except Exception as e:
  455. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  456. retval = None
  457. if existing and existing.get("device"):
  458. _LOGGER.info("Restarting device after test")
  459. existing["device"].resume()
  460. return retval
  461. def scan_for_device(id):
  462. return tinytuya.find_device(dev_id=id)