device.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """
  2. API for Tuya Local devices.
  3. """
import asyncio
import json
import logging
from threading import Lock
from time import time

import tinytuya
from homeassistant.const import (
    CONF_HOST,
    CONF_NAME,
    EVENT_HOMEASSISTANT_STARTED,
    EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import HomeAssistant, callback

from .const import (
    API_PROTOCOL_VERSIONS,
    CONF_DEVICE_ID,
    CONF_LOCAL_KEY,
    CONF_PROTOCOL_VERSION,
    DOMAIN,
)
from .helpers.device_config import possible_matches
  25. _LOGGER = logging.getLogger(__name__)
  26. class TuyaLocalDevice(object):
  27. def __init__(
  28. self,
  29. name,
  30. dev_id,
  31. address,
  32. local_key,
  33. protocol_version,
  34. hass: HomeAssistant,
  35. ):
  36. """
  37. Represents a Tuya-based device.
  38. Args:
  39. dev_id (str): The device id.
  40. address (str): The network address.
  41. local_key (str): The encryption key.
  42. protocol_version (str | number): The protocol version.
  43. """
  44. self._name = name
  45. self._children = []
  46. self._running = False
  47. self._shutdown_listener = None
  48. self._startup_listener = None
  49. self._api_protocol_version_index = None
  50. self._api_protocol_working = False
  51. self._api = tinytuya.Device(dev_id, address, local_key)
  52. self._refresh_task = None
  53. self._protocol_configured = protocol_version
  54. self._reset_cached_state()
  55. self._hass = hass
  56. # API calls to update Tuya devices are asynchronous and non-blocking.
  57. # This means you can send a change and immediately request an updated
  58. # state (like HA does), but because it has not yet finished processing
  59. # you will be returned the old state.
  60. # The solution is to keep a temporary list of changed properties that
  61. # we can overlay onto the state while we wait for the board to update
  62. # its switches.
  63. self._FAKE_IT_TIMEOUT = 5
  64. self._CACHE_TIMEOUT = 120
  65. # More attempts are needed in auto mode so we can cycle through all
  66. # the possibilities a couple of times
  67. self._AUTO_CONNECTION_ATTEMPTS = 9
  68. self._SINGLE_PROTO_CONNECTION_ATTEMPTS = 3
  69. self._lock = Lock()
  70. @property
  71. def name(self):
  72. return self._name
  73. @property
  74. def unique_id(self):
  75. """Return the unique id for this device (the dev_id)."""
  76. return self._api.id
  77. @property
  78. def device_info(self):
  79. """Return the device information for this device."""
  80. return {
  81. "identifiers": {(DOMAIN, self.unique_id)},
  82. "name": self.name,
  83. "manufacturer": "Tuya",
  84. }
  85. @property
  86. def has_returned_state(self):
  87. """Return True if the device has returned some state."""
  88. return len(self._get_cached_state()) > 1
  89. def actually_start(self, event=None):
  90. _LOGGER.debug(f"Starting monitor loop for {self.name}")
  91. self._running = True
  92. self._shutdown_listener = self._hass.bus.async_listen_once(
  93. EVENT_HOMEASSISTANT_STOP, self.async_stop
  94. )
  95. self._refresh_task = self._hass.async_create_task(self.receive_loop())
  96. def start(self):
  97. if self._hass.is_stopping:
  98. return
  99. elif self._hass.is_running:
  100. if self._startup_listener:
  101. self._startup_listener()
  102. self._startup_listener = None
  103. self.actually_start()
  104. else:
  105. self._startup_listener = self._hass.bus.async_listen_once(
  106. EVENT_HOMEASSISTANT_STARTED, self.actually_start
  107. )
  108. async def async_stop(self, event=None):
  109. _LOGGER.debug(f"Stopping monitor loop for {self.name}")
  110. self._running = False
  111. if self._shutdown_listener:
  112. self._shutdown_listener()
  113. self._shutdown_listener = None
  114. self._children.clear()
  115. if self._refresh_task:
  116. await self._refresh_task
  117. _LOGGER.debug(f"Monitor loop for {self.name} stopped")
  118. self._refresh_task = None
  119. def register_entity(self, entity):
  120. self._children.append(entity)
  121. if not self._running and not self._startup_listener:
  122. self.start()
  123. async def async_unregister_entity(self, entity):
  124. self._children.remove(entity)
  125. if not self._children:
  126. await self.async_stop()
  127. async def receive_loop(self):
  128. """Coroutine wrapper for async_receive generator."""
  129. try:
  130. async for poll in self.async_receive():
  131. if type(poll) is dict:
  132. _LOGGER.debug(f"{self.name} received {poll}")
  133. self._cached_state = self._cached_state | poll
  134. self._cached_state["updated_at"] = time()
  135. for entity in self._children:
  136. entity.async_schedule_update_ha_state()
  137. else:
  138. _LOGGER.debug(f"{self.name} received non data {poll}")
  139. _LOGGER.warning(f"{self.name} receive loop has terminated")
  140. except Exception as t:
  141. _LOGGER.exception(
  142. f"{self.name} receive loop terminated by exception {t}",
  143. )
  144. async def async_receive(self):
  145. """Receive messages from a persistent connection asynchronously."""
  146. self._api.set_socketPersistent(self._running)
  147. while self._running:
  148. try:
  149. last_cache = self._cached_state["updated_at"]
  150. now = time()
  151. if now - last_cache > self._CACHE_TIMEOUT:
  152. poll = await self._retry_on_failed_connection(
  153. lambda: self._api.status(),
  154. f"Failed to refresh device state for {self.name}",
  155. )
  156. else:
  157. await self._hass.async_add_executor_job(
  158. self._api.heartbeat,
  159. True,
  160. )
  161. poll = await self._hass.async_add_executor_job(
  162. self._api.receive,
  163. )
  164. if poll:
  165. if "Error" in poll:
  166. _LOGGER.warning(
  167. f"{self.name} error reading: {poll['Error']}",
  168. )
  169. if "Payload" in poll and poll["Payload"]:
  170. _LOGGER.info(
  171. f"{self.name} err payload: {poll['Payload']}",
  172. )
  173. else:
  174. if "dps" in poll:
  175. poll = poll["dps"]
  176. yield poll
  177. await asyncio.sleep(0.1)
  178. except asyncio.CancelledError:
  179. self._running = False
  180. # Close the persistent connection when exiting the loop
  181. self._api.set_socketPersistent(False)
  182. raise
  183. except Exception as t:
  184. _LOGGER.exception(
  185. f"{self.name} receive loop error {type(t)}:{t}",
  186. )
  187. # Close the persistent connection when exiting the loop
  188. self._api.set_socketPersistent(False)
  189. async def async_possible_types(self):
  190. cached_state = self._get_cached_state()
  191. if len(cached_state) <= 1:
  192. await self.async_refresh()
  193. cached_state = self._get_cached_state()
  194. for match in possible_matches(cached_state):
  195. yield match
  196. async def async_inferred_type(self):
  197. best_match = None
  198. best_quality = 0
  199. cached_state = {}
  200. async for config in self.async_possible_types():
  201. cached_state = self._get_cached_state()
  202. quality = config.match_quality(cached_state)
  203. _LOGGER.info(
  204. f"{self.name} considering {config.name} with quality {quality}"
  205. )
  206. if quality > best_quality:
  207. best_quality = quality
  208. best_match = config
  209. if best_match is None:
  210. _LOGGER.warning(
  211. f"Detection for {self.name} with dps {cached_state} failed",
  212. )
  213. return None
  214. return best_match.config_type
  215. async def async_refresh(self):
  216. _LOGGER.debug(f"Refreshing device state for {self.name}.")
  217. await self._retry_on_failed_connection(
  218. lambda: self._refresh_cached_state(),
  219. f"Failed to refresh device state for {self.name}.",
  220. )
  221. def get_property(self, dps_id):
  222. cached_state = self._get_cached_state()
  223. if dps_id in cached_state:
  224. return cached_state[dps_id]
  225. else:
  226. return None
  227. async def async_set_property(self, dps_id, value):
  228. await self.async_set_properties({dps_id: value})
  229. def anticipate_property_value(self, dps_id, value):
  230. """
  231. Update a value in the cached state only. This is good for when you
  232. know the device will reflect a new state in the next update, but
  233. don't want to wait for that update for the device to represent
  234. this state.
  235. The anticipated value will be cleared with the next update.
  236. """
  237. self._cached_state[dps_id] = value
  238. def _reset_cached_state(self):
  239. self._cached_state = {"updated_at": 0}
  240. self._pending_updates = {}
  241. self._last_connection = 0
  242. def _refresh_cached_state(self):
  243. new_state = self._api.status()
  244. self._cached_state = self._cached_state | new_state["dps"]
  245. self._cached_state["updated_at"] = time()
  246. _LOGGER.debug(
  247. f"{self.name} refreshed device state: {json.dumps(new_state)}",
  248. )
  249. _LOGGER.debug(
  250. f"new state (incl pending): {json.dumps(self._get_cached_state())}"
  251. )
  252. async def async_set_properties(self, properties):
  253. if len(properties) == 0:
  254. return
  255. self._add_properties_to_pending_updates(properties)
  256. await self._debounce_sending_updates()
  257. def _add_properties_to_pending_updates(self, properties):
  258. now = time()
  259. pending_updates = self._get_pending_updates()
  260. for key, value in properties.items():
  261. pending_updates[key] = {
  262. "value": value,
  263. "updated_at": now,
  264. "sent": False,
  265. }
  266. _LOGGER.debug(
  267. f"{self.name} new pending updates: {json.dumps(pending_updates)}",
  268. )
  269. async def _debounce_sending_updates(self):
  270. now = time()
  271. since = now - self._last_connection
  272. # set this now to avoid a race condition, it will be updated later
  273. # when the data is actally sent
  274. self._last_connection = now
  275. # Only delay a second if there was recently another command.
  276. # Otherwise delay 1ms, to keep things simple by reusing the
  277. # same send mechanism.
  278. waittime = 1 if since < 1.1 else 0.001
  279. await asyncio.sleep(waittime)
  280. await self._send_pending_updates()
  281. async def _send_pending_updates(self):
  282. pending_properties = self._get_unsent_properties()
  283. payload = self._api.generate_payload(
  284. tinytuya.CONTROL,
  285. pending_properties,
  286. )
  287. _LOGGER.debug(
  288. f"{self.name} sending dps update: {json.dumps(pending_properties)}"
  289. )
  290. await self._retry_on_failed_connection(
  291. lambda: self._send_payload(payload),
  292. "Failed to update device state.",
  293. )
  294. def _send_payload(self, payload):
  295. try:
  296. self._lock.acquire()
  297. self._api.send(payload)
  298. self._cached_state["updated_at"] = 0
  299. now = time()
  300. self._last_connection = now
  301. pending_updates = self._get_pending_updates()
  302. for key in list(pending_updates):
  303. pending_updates[key]["updated_at"] = now
  304. pending_updates[key]["sent"] = True
  305. finally:
  306. self._lock.release()
  307. async def _retry_on_failed_connection(self, func, error_message):
  308. if self._api_protocol_version_index is None:
  309. self._rotate_api_protocol_version()
  310. auto = (self._protocol_configured == "auto") and (
  311. not self._api_protocol_working
  312. )
  313. connections = (
  314. self._AUTO_CONNECTION_ATTEMPTS
  315. if auto
  316. else self._SINGLE_PROTO_CONNECTION_ATTEMPTS
  317. )
  318. for i in range(connections):
  319. try:
  320. retval = await self._hass.async_add_executor_job(func)
  321. self._api_protocol_working = True
  322. return retval
  323. except Exception as e:
  324. _LOGGER.debug(f"Retrying after exception {e}")
  325. if i + 1 == connections:
  326. self._reset_cached_state()
  327. self._api_protocol_working = False
  328. _LOGGER.error(error_message)
  329. if not self._api_protocol_working:
  330. self._rotate_api_protocol_version()
  331. def _get_cached_state(self):
  332. cached_state = self._cached_state.copy()
  333. return {**cached_state, **self._get_pending_properties()}
  334. def _get_pending_properties(self):
  335. return {
  336. key: property["value"]
  337. for key, property in self._get_pending_updates().items()
  338. }
  339. def _get_unsent_properties(self):
  340. return {
  341. key: info["value"]
  342. for key, info in self._get_pending_updates().items()
  343. if not info["sent"]
  344. }
  345. def _get_pending_updates(self):
  346. now = time()
  347. self._pending_updates = {
  348. key: value
  349. for key, value in self._pending_updates.items()
  350. if now - value["updated_at"] < self._FAKE_IT_TIMEOUT
  351. }
  352. return self._pending_updates
  353. def _rotate_api_protocol_version(self):
  354. if self._api_protocol_version_index is None:
  355. try:
  356. self._api_protocol_version_index = API_PROTOCOL_VERSIONS.index(
  357. self._protocol_configured
  358. )
  359. except ValueError:
  360. self._api_protocol_version_index = 0
  361. # only rotate if configured as auto
  362. elif self._protocol_configured == "auto":
  363. self._api_protocol_version_index += 1
  364. if self._api_protocol_version_index >= len(API_PROTOCOL_VERSIONS):
  365. self._api_protocol_version_index = 0
  366. new_version = API_PROTOCOL_VERSIONS[self._api_protocol_version_index]
  367. _LOGGER.info(
  368. f"Setting protocol version for {self.name} to {new_version}.",
  369. )
  370. self._api.set_version(new_version)
  371. @staticmethod
  372. def get_key_for_value(obj, value, fallback=None):
  373. keys = list(obj.keys())
  374. values = list(obj.values())
  375. return keys[values.index(value)] if value in values else fallback
  376. def setup_device(hass: HomeAssistant, config: dict):
  377. """Setup a tuya device based on passed in config."""
  378. _LOGGER.info(f"Creating device: {config[CONF_DEVICE_ID]}")
  379. hass.data[DOMAIN] = hass.data.get(DOMAIN, {})
  380. device = TuyaLocalDevice(
  381. config[CONF_NAME],
  382. config[CONF_DEVICE_ID],
  383. config[CONF_HOST],
  384. config[CONF_LOCAL_KEY],
  385. config[CONF_PROTOCOL_VERSION],
  386. hass,
  387. )
  388. hass.data[DOMAIN][config[CONF_DEVICE_ID]] = {"device": device}
  389. return device
  390. async def async_delete_device(hass: HomeAssistant, config: dict):
  391. _LOGGER.info(f"Deleting device: {config[CONF_DEVICE_ID]}")
  392. await hass.data[DOMAIN][config[CONF_DEVICE_ID]]["device"].async_stop()
  393. del hass.data[DOMAIN][config[CONF_DEVICE_ID]]["device"]