device.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. """
  2. API for Tuya Local devices.
  3. """
  4. import asyncio
  5. import json
  6. import logging
  7. import tinytuya
  8. from threading import Lock
  9. from time import time
  10. from homeassistant.const import (
  11. CONF_HOST,
  12. CONF_NAME,
  13. EVENT_HOMEASSISTANT_STARTED,
  14. EVENT_HOMEASSISTANT_STOP,
  15. )
  16. from homeassistant.core import HomeAssistant
  17. from .const import (
  18. API_PROTOCOL_VERSIONS,
  19. CONF_DEVICE_ID,
  20. CONF_LOCAL_KEY,
  21. CONF_POLL_ONLY,
  22. CONF_PROTOCOL_VERSION,
  23. DOMAIN,
  24. )
  25. from .helpers.device_config import possible_matches
  26. _LOGGER = logging.getLogger(__name__)
  27. def non_json(input):
  28. """Handler for json_dumps when used for debugging."""
  29. return f"Non-JSON: ({input})"
  30. class TuyaLocalDevice(object):
  31. def __init__(
  32. self,
  33. name,
  34. dev_id,
  35. address,
  36. local_key,
  37. protocol_version,
  38. hass: HomeAssistant,
  39. poll_only=False,
  40. ):
  41. """
  42. Represents a Tuya-based device.
  43. Args:
  44. dev_id (str): The device id.
  45. address (str): The network address.
  46. local_key (str): The encryption key.
  47. protocol_version (str | number): The protocol version.
  48. hass (HomeAssistant): The Home Assistant instance.
  49. poll_only (bool): True if the device should be polled only
  50. """
  51. self._name = name
  52. self._children = []
  53. self._force_dps = []
  54. self._running = False
  55. self._shutdown_listener = None
  56. self._startup_listener = None
  57. self._api_protocol_version_index = None
  58. self._api_protocol_working = False
  59. self._api = tinytuya.Device(dev_id, address, local_key)
  60. # we handle retries at a higher level so we can rotate protocol version
  61. self._api.set_socketRetryLimit(1)
  62. self._refresh_task = None
  63. self._protocol_configured = protocol_version
  64. self._poll_only = poll_only
  65. self._temporary_poll = False
  66. self._reset_cached_state()
  67. self._hass = hass
  68. # API calls to update Tuya devices are asynchronous and non-blocking.
  69. # This means you can send a change and immediately request an updated
  70. # state (like HA does), but because it has not yet finished processing
  71. # you will be returned the old state.
  72. # The solution is to keep a temporary list of changed properties that
  73. # we can overlay onto the state while we wait for the board to update
  74. # its switches.
  75. self._FAKE_IT_TIMEOUT = 5
  76. self._CACHE_TIMEOUT = 120
  77. # More attempts are needed in auto mode so we can cycle through all
  78. # the possibilities a couple of times
  79. self._AUTO_CONNECTION_ATTEMPTS = len(API_PROTOCOL_VERSIONS) * 2 + 1
  80. self._SINGLE_PROTO_CONNECTION_ATTEMPTS = 3
  81. self._lock = Lock()
  82. @property
  83. def name(self):
  84. return self._name
  85. @property
  86. def unique_id(self):
  87. """Return the unique id for this device (the dev_id)."""
  88. return self._api.id
  89. @property
  90. def device_info(self):
  91. """Return the device information for this device."""
  92. return {
  93. "identifiers": {(DOMAIN, self.unique_id)},
  94. "name": self.name,
  95. "manufacturer": "Tuya",
  96. }
  97. @property
  98. def has_returned_state(self):
  99. """Return True if the device has returned some state."""
  100. return len(self._get_cached_state()) > 1
  101. def actually_start(self, event=None):
  102. _LOGGER.debug(f"Starting monitor loop for {self.name}")
  103. self._running = True
  104. self._shutdown_listener = self._hass.bus.async_listen_once(
  105. EVENT_HOMEASSISTANT_STOP, self.async_stop
  106. )
  107. self._refresh_task = self._hass.async_create_task(self.receive_loop())
  108. def start(self):
  109. if self._hass.is_stopping:
  110. return
  111. elif self._hass.is_running:
  112. if self._startup_listener:
  113. self._startup_listener()
  114. self._startup_listener = None
  115. self.actually_start()
  116. else:
  117. self._startup_listener = self._hass.bus.async_listen_once(
  118. EVENT_HOMEASSISTANT_STARTED, self.actually_start
  119. )
  120. async def async_stop(self, event=None):
  121. _LOGGER.debug(f"Stopping monitor loop for {self.name}")
  122. self._running = False
  123. if self._shutdown_listener:
  124. self._shutdown_listener()
  125. self._shutdown_listener = None
  126. self._children.clear()
  127. self._force_dps.clear()
  128. if self._refresh_task:
  129. await self._refresh_task
  130. _LOGGER.debug(f"Monitor loop for {self.name} stopped")
  131. self._refresh_task = None
  132. def register_entity(self, entity):
  133. # If this is the first child entity to register, refresh the device
  134. # state
  135. should_poll = len(self._children) == 0
  136. self._children.append(entity)
  137. for dp in entity._config.dps():
  138. if dp.force and dp.id not in self._force_dps:
  139. self._force_dps.append(dp.id)
  140. if not self._running and not self._startup_listener:
  141. self.start()
  142. if self.has_returned_state:
  143. entity.async_schedule_update_ha_state()
  144. elif should_poll:
  145. entity.async_schedule_update_ha_state(True)
  146. async def async_unregister_entity(self, entity):
  147. self._children.remove(entity)
  148. if not self._children:
  149. await self.async_stop()
  150. async def receive_loop(self):
  151. """Coroutine wrapper for async_receive generator."""
  152. try:
  153. async for poll in self.async_receive():
  154. if type(poll) is dict:
  155. _LOGGER.debug(f"{self.name} received {poll}")
  156. self._cached_state = self._cached_state | poll
  157. self._cached_state["updated_at"] = time()
  158. for entity in self._children:
  159. entity.async_schedule_update_ha_state()
  160. else:
  161. _LOGGER.debug(f"{self.name} received non data {poll}")
  162. _LOGGER.warning(f"{self.name} receive loop has terminated")
  163. except Exception as t:
  164. _LOGGER.exception(
  165. f"{self.name} receive loop terminated by exception {t}",
  166. )
  167. @property
  168. def should_poll(self):
  169. return self._poll_only or self._temporary_poll or not self.has_returned_state
  170. def pause(self):
  171. self._temporary_poll = True
  172. def resume(self):
  173. self._temporary_poll = False
  174. async def async_receive(self):
  175. """Receive messages from a persistent connection asynchronously."""
  176. # If we didn't yet get any state from the device, we may need to
  177. # negotiate the protocol before making the connection persistent
  178. persist = not self.should_poll
  179. self._api.set_socketPersistent(persist)
  180. while self._running:
  181. try:
  182. last_cache = self._cached_state["updated_at"]
  183. now = time()
  184. if persist == self.should_poll:
  185. # use persistent connections after initial communication
  186. # has been established. Until then, we need to rotate
  187. # the protocol version, which seems to require a fresh
  188. # connection.
  189. persist = not self.should_poll
  190. self._api.set_socketPersistent(persist)
  191. if now - last_cache > self._CACHE_TIMEOUT:
  192. if self._force_dps:
  193. poll = await self._retry_on_failed_connection(
  194. lambda: self._api.updatedps(self._force_dps),
  195. f"Failed to refresh device state for {self.name}",
  196. )
  197. else:
  198. poll = await self._retry_on_failed_connection(
  199. lambda: self._api.status(),
  200. f"Failed to refresh device state for {self.name}",
  201. )
  202. else:
  203. await self._hass.async_add_executor_job(
  204. self._api.heartbeat,
  205. True,
  206. )
  207. poll = await self._hass.async_add_executor_job(
  208. self._api.receive,
  209. )
  210. if poll:
  211. if "Error" in poll:
  212. _LOGGER.warning(
  213. f"{self.name} error reading: {poll['Error']}",
  214. )
  215. if "Payload" in poll and poll["Payload"]:
  216. _LOGGER.info(
  217. f"{self.name} err payload: {poll['Payload']}",
  218. )
  219. else:
  220. if "dps" in poll:
  221. poll = poll["dps"]
  222. yield poll
  223. await asyncio.sleep(0.1 if self.has_returned_state else 5)
  224. except asyncio.CancelledError:
  225. self._running = False
  226. # Close the persistent connection when exiting the loop
  227. self._api.set_socketPersistent(False)
  228. raise
  229. except Exception as t:
  230. _LOGGER.exception(
  231. f"{self.name} receive loop error {type(t)}:{t}",
  232. )
  233. await asyncio.sleep(5)
  234. # Close the persistent connection when exiting the loop
  235. self._api.set_socketPersistent(False)
  236. async def async_possible_types(self):
  237. cached_state = self._get_cached_state()
  238. if len(cached_state) <= 1:
  239. await self.async_refresh()
  240. cached_state = self._get_cached_state()
  241. for match in possible_matches(cached_state):
  242. yield match
  243. async def async_inferred_type(self):
  244. best_match = None
  245. best_quality = 0
  246. cached_state = {}
  247. async for config in self.async_possible_types():
  248. cached_state = self._get_cached_state()
  249. quality = config.match_quality(cached_state)
  250. _LOGGER.info(
  251. f"{self.name} considering {config.name} with quality {quality}"
  252. )
  253. if quality > best_quality:
  254. best_quality = quality
  255. best_match = config
  256. if best_match is None:
  257. _LOGGER.warning(
  258. f"Detection for {self.name} with dps {cached_state} failed",
  259. )
  260. return None
  261. return best_match.config_type
  262. async def async_refresh(self):
  263. _LOGGER.debug(f"Refreshing device state for {self.name}.")
  264. await self._retry_on_failed_connection(
  265. lambda: self._refresh_cached_state(),
  266. f"Failed to refresh device state for {self.name}.",
  267. )
  268. def get_property(self, dps_id):
  269. cached_state = self._get_cached_state()
  270. if dps_id in cached_state:
  271. return cached_state[dps_id]
  272. else:
  273. return None
  274. async def async_set_property(self, dps_id, value):
  275. await self.async_set_properties({dps_id: value})
  276. def anticipate_property_value(self, dps_id, value):
  277. """
  278. Update a value in the cached state only. This is good for when you
  279. know the device will reflect a new state in the next update, but
  280. don't want to wait for that update for the device to represent
  281. this state.
  282. The anticipated value will be cleared with the next update.
  283. """
  284. self._cached_state[dps_id] = value
  285. def _reset_cached_state(self):
  286. self._cached_state = {"updated_at": 0}
  287. self._pending_updates = {}
  288. self._last_connection = 0
  289. def _refresh_cached_state(self):
  290. new_state = self._api.status()
  291. self._cached_state = self._cached_state | new_state["dps"]
  292. self._cached_state["updated_at"] = time()
  293. for entity in self._children:
  294. entity.async_schedule_update_ha_state()
  295. _LOGGER.debug(
  296. f"{self.name} refreshed device state: {json.dumps(new_state, default=non_json)}",
  297. )
  298. _LOGGER.debug(
  299. f"new state (incl pending): {json.dumps(self._get_cached_state(), default=non_json)}"
  300. )
  301. async def async_set_properties(self, properties):
  302. if len(properties) == 0:
  303. return
  304. self._add_properties_to_pending_updates(properties)
  305. await self._debounce_sending_updates()
  306. def _add_properties_to_pending_updates(self, properties):
  307. now = time()
  308. pending_updates = self._get_pending_updates()
  309. for key, value in properties.items():
  310. pending_updates[key] = {
  311. "value": value,
  312. "updated_at": now,
  313. "sent": False,
  314. }
  315. _LOGGER.debug(
  316. f"{self.name} new pending updates: {json.dumps(pending_updates, default=non_json)}",
  317. )
  318. async def _debounce_sending_updates(self):
  319. now = time()
  320. since = now - self._last_connection
  321. # set this now to avoid a race condition, it will be updated later
  322. # when the data is actally sent
  323. self._last_connection = now
  324. # Only delay a second if there was recently another command.
  325. # Otherwise delay 1ms, to keep things simple by reusing the
  326. # same send mechanism.
  327. waittime = 1 if since < 1.1 else 0.001
  328. await asyncio.sleep(waittime)
  329. await self._send_pending_updates()
  330. async def _send_pending_updates(self):
  331. pending_properties = self._get_unsent_properties()
  332. _LOGGER.debug(
  333. f"{self.name} sending dps update: {json.dumps(pending_properties, default=non_json)}"
  334. )
  335. await self._retry_on_failed_connection(
  336. lambda: self._set_values(pending_properties),
  337. "Failed to update device state.",
  338. )
  339. def _set_values(self, properties):
  340. try:
  341. self._lock.acquire()
  342. self._api.set_multiple_values(properties, nowait=True)
  343. self._cached_state["updated_at"] = 0
  344. now = time()
  345. self._last_connection = now
  346. pending_updates = self._get_pending_updates()
  347. for key in properties.keys():
  348. pending_updates[key]["updated_at"] = now
  349. pending_updates[key]["sent"] = True
  350. finally:
  351. self._lock.release()
  352. async def _retry_on_failed_connection(self, func, error_message):
  353. if self._api_protocol_version_index is None:
  354. await self._rotate_api_protocol_version()
  355. auto = (self._protocol_configured == "auto") and (
  356. not self._api_protocol_working
  357. )
  358. connections = (
  359. self._AUTO_CONNECTION_ATTEMPTS
  360. if auto
  361. else self._SINGLE_PROTO_CONNECTION_ATTEMPTS
  362. )
  363. for i in range(connections):
  364. try:
  365. if not self._hass.is_stopping:
  366. retval = await self._hass.async_add_executor_job(func)
  367. if type(retval) is dict and "Error" in retval:
  368. raise AttributeError
  369. self._api_protocol_working = True
  370. return retval
  371. except Exception as e:
  372. _LOGGER.debug(
  373. f"Retrying after exception {e} ({i}/{connections})",
  374. )
  375. if i + 1 == connections:
  376. self._reset_cached_state()
  377. self._api_protocol_working = False
  378. for entity in self._children:
  379. entity.async_schedule_update_ha_state()
  380. _LOGGER.error(error_message)
  381. if not self._api_protocol_working:
  382. await self._rotate_api_protocol_version()
  383. def _get_cached_state(self):
  384. cached_state = self._cached_state.copy()
  385. return {**cached_state, **self._get_pending_properties()}
  386. def _get_pending_properties(self):
  387. return {
  388. key: property["value"]
  389. for key, property in self._get_pending_updates().items()
  390. }
  391. def _get_unsent_properties(self):
  392. return {
  393. key: info["value"]
  394. for key, info in self._get_pending_updates().items()
  395. if not info["sent"]
  396. }
  397. def _get_pending_updates(self):
  398. now = time()
  399. self._pending_updates = {
  400. key: value
  401. for key, value in self._pending_updates.items()
  402. if now - value["updated_at"] < self._FAKE_IT_TIMEOUT
  403. }
  404. return self._pending_updates
  405. async def _rotate_api_protocol_version(self):
  406. if self._api_protocol_version_index is None:
  407. try:
  408. self._api_protocol_version_index = API_PROTOCOL_VERSIONS.index(
  409. self._protocol_configured
  410. )
  411. except ValueError:
  412. self._api_protocol_version_index = 0
  413. # only rotate if configured as auto
  414. elif self._protocol_configured == "auto":
  415. self._api_protocol_version_index += 1
  416. if self._api_protocol_version_index >= len(API_PROTOCOL_VERSIONS):
  417. self._api_protocol_version_index = 0
  418. new_version = API_PROTOCOL_VERSIONS[self._api_protocol_version_index]
  419. _LOGGER.info(
  420. f"Setting protocol version for {self.name} to {new_version}.",
  421. )
  422. await self._hass.async_add_executor_job(
  423. self._api.set_version,
  424. new_version,
  425. )
  426. @staticmethod
  427. def get_key_for_value(obj, value, fallback=None):
  428. keys = list(obj.keys())
  429. values = list(obj.values())
  430. return keys[values.index(value)] if value in values else fallback
  431. def setup_device(hass: HomeAssistant, config: dict):
  432. """Setup a tuya device based on passed in config."""
  433. _LOGGER.info(f"Creating device: {config[CONF_DEVICE_ID]}")
  434. hass.data[DOMAIN] = hass.data.get(DOMAIN, {})
  435. device = TuyaLocalDevice(
  436. config[CONF_NAME],
  437. config[CONF_DEVICE_ID],
  438. config[CONF_HOST],
  439. config[CONF_LOCAL_KEY],
  440. config[CONF_PROTOCOL_VERSION],
  441. hass,
  442. config[CONF_POLL_ONLY],
  443. )
  444. hass.data[DOMAIN][config[CONF_DEVICE_ID]] = {"device": device}
  445. return device
  446. async def async_delete_device(hass: HomeAssistant, config: dict):
  447. _LOGGER.info(f"Deleting device: {config[CONF_DEVICE_ID]}")
  448. await hass.data[DOMAIN][config[CONF_DEVICE_ID]]["device"].async_stop()
  449. del hass.data[DOMAIN][config[CONF_DEVICE_ID]]["device"]