config_flow.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import asyncio
  2. import logging
  3. import voluptuous as vol
  4. from homeassistant import config_entries
  5. from homeassistant.const import CONF_HOST, CONF_NAME
  6. from homeassistant.core import HomeAssistant, callback
  7. from . import DOMAIN
  8. from .const import (
  9. API_PROTOCOL_VERSIONS,
  10. CONF_DEVICE_CID,
  11. CONF_DEVICE_ID,
  12. CONF_LOCAL_KEY,
  13. CONF_POLL_ONLY,
  14. CONF_PROTOCOL_VERSION,
  15. CONF_TYPE,
  16. )
  17. from .device import TuyaLocalDevice
  18. from .helpers.config import get_device_id
  19. from .helpers.device_config import get_config
  20. from .helpers.log import log_json
  21. _LOGGER = logging.getLogger(__name__)
  22. class ConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
  23. VERSION = 12
  24. CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
  25. device = None
  26. data = {}
  27. async def async_step_user(self, user_input=None):
  28. errors = {}
  29. devid_opts = {}
  30. host_opts = {"default": "Auto"}
  31. key_opts = {}
  32. proto_opts = {"default": 3.3}
  33. polling_opts = {"default": False}
  34. devcid_opts = {}
  35. if user_input is not None:
  36. await self.async_set_unique_id(get_device_id(user_input))
  37. self._abort_if_unique_id_configured()
  38. self.device = await async_test_connection(user_input, self.hass)
  39. if self.device:
  40. self.data = user_input
  41. return await self.async_step_select_type()
  42. else:
  43. errors["base"] = "connection"
  44. devid_opts["default"] = user_input[CONF_DEVICE_ID]
  45. host_opts["default"] = user_input[CONF_HOST]
  46. key_opts["default"] = user_input[CONF_LOCAL_KEY]
  47. if CONF_DEVICE_CID in user_input:
  48. devcid_opts["default"] = user_input[CONF_DEVICE_CID]
  49. proto_opts["default"] = user_input[CONF_PROTOCOL_VERSION]
  50. polling_opts["default"] = user_input[CONF_POLL_ONLY]
  51. return self.async_show_form(
  52. step_id="user",
  53. data_schema=vol.Schema(
  54. {
  55. vol.Required(CONF_DEVICE_ID, **devid_opts): str,
  56. vol.Required(CONF_HOST, **host_opts): str,
  57. vol.Required(CONF_LOCAL_KEY, **key_opts): str,
  58. vol.Required(
  59. CONF_PROTOCOL_VERSION,
  60. **proto_opts,
  61. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  62. vol.Required(CONF_POLL_ONLY, **polling_opts): bool,
  63. vol.Optional(CONF_DEVICE_CID, **devcid_opts): str,
  64. }
  65. ),
  66. errors=errors,
  67. )
  68. async def async_step_select_type(self, user_input=None):
  69. if user_input is not None:
  70. self.data[CONF_TYPE] = user_input[CONF_TYPE]
  71. return await self.async_step_choose_entities()
  72. types = []
  73. best_match = 0
  74. best_matching_type = None
  75. async for type in self.device.async_possible_types():
  76. types.append(type.config_type)
  77. q = type.match_quality(self.device._get_cached_state())
  78. if q > best_match:
  79. best_match = q
  80. best_matching_type = type.config_type
  81. if best_match < 100:
  82. best_match = int(best_match)
  83. dps = self.device._get_cached_state()
  84. _LOGGER.warning(
  85. "Device matches %s with quality of %d%%. DPS: %s",
  86. best_matching_type,
  87. best_match,
  88. log_json(dps),
  89. )
  90. _LOGGER.warning(
  91. "Report this to https://github.com/make-all/tuya-local/issues/"
  92. )
  93. if types:
  94. return self.async_show_form(
  95. step_id="select_type",
  96. data_schema=vol.Schema(
  97. {
  98. vol.Required(
  99. CONF_TYPE,
  100. default=best_matching_type,
  101. ): vol.In(types),
  102. }
  103. ),
  104. )
  105. else:
  106. return self.async_abort(reason="not_supported")
  107. async def async_step_choose_entities(self, user_input=None):
  108. if user_input is not None:
  109. title = user_input[CONF_NAME]
  110. del user_input[CONF_NAME]
  111. return self.async_create_entry(
  112. title=title, data={**self.data, **user_input}
  113. )
  114. config = get_config(self.data[CONF_TYPE])
  115. schema = {vol.Required(CONF_NAME, default=config.name): str}
  116. return self.async_show_form(
  117. step_id="choose_entities",
  118. data_schema=vol.Schema(schema),
  119. )
  120. @staticmethod
  121. @callback
  122. def async_get_options_flow(config_entry):
  123. return OptionsFlowHandler(config_entry)
  124. class OptionsFlowHandler(config_entries.OptionsFlow):
  125. def __init__(self, config_entry):
  126. """Initialize options flow."""
  127. self.config_entry = config_entry
  128. async def async_step_init(self, user_input=None):
  129. return await self.async_step_user(user_input)
  130. async def async_step_user(self, user_input=None):
  131. """Manage the options."""
  132. errors = {}
  133. config = {**self.config_entry.data, **self.config_entry.options}
  134. if user_input is not None:
  135. config = {**config, **user_input}
  136. device = await async_test_connection(config, self.hass)
  137. if device:
  138. return self.async_create_entry(title="", data=user_input)
  139. else:
  140. errors["base"] = "connection"
  141. schema = {
  142. vol.Required(
  143. CONF_LOCAL_KEY,
  144. default=config.get(CONF_LOCAL_KEY, ""),
  145. ): str,
  146. vol.Required(CONF_HOST, default=config.get(CONF_HOST, "")): str,
  147. vol.Required(
  148. CONF_PROTOCOL_VERSION,
  149. default=config.get(CONF_PROTOCOL_VERSION, "auto"),
  150. ): vol.In(["auto"] + API_PROTOCOL_VERSIONS),
  151. vol.Required(
  152. CONF_POLL_ONLY, default=config.get(CONF_POLL_ONLY, False)
  153. ): bool,
  154. vol.Optional(
  155. CONF_DEVICE_CID,
  156. default=config.get(CONF_DEVICE_CID, ""),
  157. ): str,
  158. }
  159. cfg = get_config(config[CONF_TYPE])
  160. if cfg is None:
  161. return self.async_abort(reason="not_supported")
  162. return self.async_show_form(
  163. step_id="user",
  164. data_schema=vol.Schema(schema),
  165. errors=errors,
  166. )
  167. async def async_test_connection(config: dict, hass: HomeAssistant):
  168. domain_data = hass.data.get(DOMAIN)
  169. existing = domain_data.get(get_device_id(config)) if domain_data else None
  170. if existing:
  171. existing["device"].pause()
  172. await asyncio.sleep(5)
  173. try:
  174. subdevice_id = config.get(CONF_DEVICE_CID)
  175. device = TuyaLocalDevice(
  176. "Test",
  177. config[CONF_DEVICE_ID],
  178. config[CONF_HOST],
  179. config[CONF_LOCAL_KEY],
  180. config[CONF_PROTOCOL_VERSION],
  181. subdevice_id,
  182. hass,
  183. True,
  184. )
  185. await device.async_refresh()
  186. retval = device if device.has_returned_state else None
  187. except Exception as e:
  188. _LOGGER.warning("Connection test failed with %s %s", type(e), e)
  189. retval = None
  190. if existing:
  191. existing["device"].resume()
  192. return retval