device.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. """
  2. API for Tuya Local devices.
  3. """
  4. import asyncio
  5. import logging
  6. import tinytuya
  7. from threading import Lock
  8. from time import time
  9. from homeassistant.const import (
  10. CONF_HOST,
  11. CONF_NAME,
  12. EVENT_HOMEASSISTANT_STARTED,
  13. EVENT_HOMEASSISTANT_STOP,
  14. )
  15. from homeassistant.core import HomeAssistant
  16. from .const import (
  17. API_PROTOCOL_VERSIONS,
  18. CONF_DEVICE_ID,
  19. CONF_LOCAL_KEY,
  20. CONF_POLL_ONLY,
  21. CONF_PROTOCOL_VERSION,
  22. DOMAIN,
  23. )
  24. from .helpers.device_config import possible_matches
  25. from .helpers.log import log_json
  26. _LOGGER = logging.getLogger(__name__)
  27. class TuyaLocalDevice(object):
  def __init__(
      self,
      name,
      dev_id,
      address,
      local_key,
      protocol_version,
      hass: HomeAssistant,
      poll_only=False,
  ):
      """
      Represents a Tuya-based device.

      Args:
          name: The name of the device, as returned by the `name`
              property and used in log messages.
          dev_id (str): The device id.
          address (str): The network address.
          local_key (str): The encryption key.
          protocol_version (str | number): The protocol version.
          hass (HomeAssistant): The Home Assistant instance.
          poll_only (bool): True if the device should be polled only

      Raises:
          Exception: any error raised while constructing the underlying
              tinytuya device is logged and re-raised unchanged.
      """
      self._name = name
      self._children = []
      self._force_dps = []
      self._running = False
      self._shutdown_listener = None
      self._startup_listener = None
      self._api_protocol_version_index = None
      self._api_protocol_working = False
      try:
          self._api = tinytuya.Device(dev_id, address, local_key)
      except Exception as e:
          _LOGGER.error(
              "%s: %s while initialising device %s",
              type(e),
              e,
              dev_id,
          )
          # bare raise keeps the original traceback intact
          raise

      # we handle retries at a higher level so we can rotate protocol version
      self._api.set_socketRetryLimit(1)
      self._refresh_task = None
      self._protocol_configured = protocol_version
      self._poll_only = poll_only
      self._temporary_poll = False
      self._reset_cached_state()
      self._hass = hass

      # API calls to update Tuya devices are asynchronous and non-blocking.
      # This means you can send a change and immediately request an updated
      # state (like HA does), but because it has not yet finished processing
      # you will be returned the old state.
      # The solution is to keep a temporary list of changed properties that
      # we can overlay onto the state while we wait for the board to update
      # its switches.
      self._FAKE_IT_TIMEOUT = 5
      self._CACHE_TIMEOUT = 30
      # More attempts are needed in auto mode so we can cycle through all
      # the possibilities a couple of times
      self._AUTO_CONNECTION_ATTEMPTS = len(API_PROTOCOL_VERSIONS) * 2 + 1
      self._SINGLE_PROTO_CONNECTION_ATTEMPTS = 3
      self._lock = Lock()
  88. @property
  89. def name(self):
  90. return self._name
  91. @property
  92. def unique_id(self):
  93. """Return the unique id for this device (the dev_id)."""
  94. return self._api.id
  @property
  def device_info(self):
      """Return the device information dictionary for this device."""
      identifiers = {(DOMAIN, self.unique_id)}
      return dict(
          identifiers=identifiers,
          name=self.name,
          manufacturer="Tuya",
      )
  103. @property
  104. def has_returned_state(self):
  105. """Return True if the device has returned some state."""
  106. return len(self._get_cached_state()) > 1
  def actually_start(self, event=None):
      """
      Mark the device as running and launch the receive loop.

      Also registers async_stop to run once when Home Assistant stops.
      The event argument is not used; it lets this method serve directly
      as a bus listener.
      """
      _LOGGER.debug("Starting monitor loop for %s", self.name)
      self._running = True
      hass = self._hass
      self._shutdown_listener = hass.bus.async_listen_once(
          EVENT_HOMEASSISTANT_STOP,
          self.async_stop,
      )
      self._refresh_task = hass.async_create_task(self.receive_loop())
  114. def start(self):
  115. if self._hass.is_stopping:
  116. return
  117. elif self._hass.is_running:
  118. if self._startup_listener:
  119. self._startup_listener()
  120. self._startup_listener = None
  121. self.actually_start()
  122. else:
  123. self._startup_listener = self._hass.bus.async_listen_once(
  124. EVENT_HOMEASSISTANT_STARTED, self.actually_start
  125. )
  async def async_stop(self, event=None):
      """
      Stop the monitor loop and wait for the receive task to finish.

      Cancels the shutdown listener if one is registered and forgets all
      registered entities and forced dps. The event argument is not used;
      it lets this method serve directly as a bus listener.
      """
      _LOGGER.debug("Stopping monitor loop for %s", self.name)
      self._running = False

      cancel_shutdown_listener = self._shutdown_listener
      if cancel_shutdown_listener:
          cancel_shutdown_listener()
          self._shutdown_listener = None

      self._children.clear()
      self._force_dps.clear()

      # The loop sees _running go False and exits by itself; just wait.
      if self._refresh_task:
          await self._refresh_task
      _LOGGER.debug("Monitor loop for %s stopped", self.name)
      self._refresh_task = None
  138. def register_entity(self, entity):
  139. # If this is the first child entity to register, and HA is still
  140. # starting, refresh the device state so it shows as available without
  141. # waiting for startup to complete.
  142. should_poll = len(self._children) == 0 and not self._hass.is_running
  143. self._children.append(entity)
  144. for dp in entity._config.dps():
  145. if dp.force and dp.id not in self._force_dps:
  146. self._force_dps.append(int(dp.id))
  147. if not self._running and not self._startup_listener:
  148. self.start()
  149. if self.has_returned_state:
  150. entity.async_schedule_update_ha_state()
  151. elif should_poll:
  152. entity.async_schedule_update_ha_state(True)
  153. async def async_unregister_entity(self, entity):
  154. self._children.remove(entity)
  155. if not self._children:
  156. await self.async_stop()
  async def receive_loop(self):
      """
      Consume the async_receive generator and fold each poll into the cache.

      Every dict yielded is merged into the cached state and all child
      entities are asked to update. Anything else is only logged.
      Exceptions are logged here and not propagated.
      """
      try:
          async for message in self.async_receive():
              if type(message) is not dict:
                  _LOGGER.debug(
                      "%s received non data %s",
                      self.name,
                      log_json(message),
                  )
                  continue

              _LOGGER.debug(
                  "%s received %s",
                  self.name,
                  log_json(message),
              )
              # "full_poll" is a marker added by async_receive, not a dp.
              was_full_poll = message.pop("full_poll", False)
              self._cached_state = {**self._cached_state, **message}
              self._cached_state["updated_at"] = time()

              for child in self._children:
                  # clear non-persistent dps that were not in a full poll
                  if was_full_poll:
                      for dp in child._config.dps():
                          if dp.persist or dp.id in message:
                              continue
                          self._cached_state.pop(dp.id, None)
                  child.async_schedule_update_ha_state()

          _LOGGER.warning("%s receive loop has terminated", self.name)
      except Exception as err:
          _LOGGER.exception(
              "%s receive loop terminated by exception %s", self.name, err
          )
  188. @property
  189. def should_poll(self):
  190. return self._poll_only or self._temporary_poll or not self.has_returned_state
  191. def pause(self):
  192. self._temporary_poll = True
  193. def resume(self):
  194. self._temporary_poll = False
  195. async def async_receive(self):
  196. """Receive messages from a persistent connection asynchronously."""
  197. # If we didn't yet get any state from the device, we may need to
  198. # negotiate the protocol before making the connection persistent
  199. persist = not self.should_poll
  200. # flag to alternate updatedps and status calls to ensure we get
  201. # all dps updated
  202. dps_updated = False
  203. self._api.set_socketPersistent(persist)
  204. while self._running:
  205. try:
  206. last_cache = self._cached_state.get("updated_at", 0)
  207. now = time()
  208. full_poll = False
  209. if persist == self.should_poll:
  210. # use persistent connections after initial communication
  211. # has been established. Until then, we need to rotate
  212. # the protocol version, which seems to require a fresh
  213. # connection.
  214. persist = not self.should_poll
  215. self._api.set_socketPersistent(persist)
  216. if now - last_cache > self._CACHE_TIMEOUT:
  217. if (
  218. self._force_dps
  219. and not dps_updated
  220. and self._api_protocol_working
  221. ):
  222. poll = await self._retry_on_failed_connection(
  223. lambda: self._api.updatedps(self._force_dps),
  224. f"Failed to refresh device state for {self.name}",
  225. )
  226. dps_updated = True
  227. else:
  228. poll = await self._retry_on_failed_connection(
  229. lambda: self._api.status(),
  230. f"Failed to refresh device state for {self.name}",
  231. )
  232. dps_updated = False
  233. full_poll = True
  234. elif persist:
  235. await self._hass.async_add_executor_job(
  236. self._api.heartbeat,
  237. True,
  238. )
  239. poll = await self._hass.async_add_executor_job(
  240. self._api.receive,
  241. )
  242. else:
  243. asyncio.sleep(5)
  244. poll = None
  245. if poll:
  246. if "Error" in poll:
  247. _LOGGER.warning(
  248. "%s error reading: %s", self.name, poll["Error"]
  249. )
  250. if "Payload" in poll and poll["Payload"]:
  251. _LOGGER.info(
  252. "%s err payload: %s",
  253. self.name,
  254. poll["Payload"],
  255. )
  256. else:
  257. if "dps" in poll:
  258. poll = poll["dps"]
  259. poll["full_poll"] = full_poll
  260. yield poll
  261. await asyncio.sleep(0.1 if self.has_returned_state else 5)
  262. except asyncio.CancelledError:
  263. self._running = False
  264. # Close the persistent connection when exiting the loop
  265. self._api.set_socketPersistent(False)
  266. raise
  267. except Exception as t:
  268. _LOGGER.exception(
  269. "%s receive loop error %s:%s",
  270. self.name,
  271. type(t),
  272. t,
  273. )
  274. await asyncio.sleep(5)
  275. # Close the persistent connection when exiting the loop
  276. self._api.set_socketPersistent(False)
  async def async_possible_types(self):
      """
      Yield every device config that could match the current state.

      If no state has been cached yet, the device is refreshed first.
      """
      state = self._get_cached_state()
      no_state_yet = len(state) <= 1
      if no_state_yet:
          # in case of device22 devices, we need to poll them with a dp
          # that exists on the device to get anything back. Most switch-like
          # devices have dp 1. Lights generally start from 20. 101 is where
          # vendor specific dps start. Between them, these three should cover
          # most devices.
          self._api.set_dpsUsed({"1": None, "20": None, "101": None})
          await self.async_refresh()
          state = self._get_cached_state()

      for candidate in possible_matches(state):
          yield candidate
  async def async_inferred_type(self):
      """
      Return the config type of the best matching device config.

      Each possible config is scored with its match_quality and the one
      with the highest score above zero wins.

      Returns:
          The `config_type` of the best match, or None if nothing matched.
      """
      # NOTE(review): this snapshot is taken before async_possible_types()
      # gets a chance to refresh the device, so candidates may be scored
      # against an older state than the one they were selected from.
      # Confirm that this is intended.
      cached_state = self._get_cached_state()

      winner = None
      top_quality = 0
      async for candidate in self.async_possible_types():
          quality = candidate.match_quality(cached_state)
          _LOGGER.info(
              "%s considering %s with quality %s",
              self.name,
              candidate.name,
              quality,
          )
          if quality > top_quality:
              top_quality, winner = quality, candidate

      if winner is not None:
          return winner.config_type

      _LOGGER.warning(
          "Detection for %s with dps %s failed",
          self.name,
          log_json(cached_state),
      )
      return None
  async def async_refresh(self):
      """Refresh the cached device state, retrying on connection failure."""
      _LOGGER.debug("Refreshing device state for %s", self.name)
      failure_message = f"Failed to refresh device state for {self.name}."
      await self._retry_on_failed_connection(
          lambda: self._refresh_cached_state(),
          failure_message,
      )
  319. def get_property(self, dps_id):
  320. cached_state = self._get_cached_state()
  321. return cached_state.get(dps_id)
  322. async def async_set_property(self, dps_id, value):
  323. await self.async_set_properties({dps_id: value})
  324. def anticipate_property_value(self, dps_id, value):
  325. """
  326. Update a value in the cached state only. This is good for when you
  327. know the device will reflect a new state in the next update, but
  328. don't want to wait for that update for the device to represent
  329. this state.
  330. The anticipated value will be cleared with the next update.
  331. """
  332. self._cached_state[dps_id] = value
  333. def _reset_cached_state(self):
  334. self._cached_state = {"updated_at": 0}
  335. self._pending_updates = {}
  336. self._last_connection = 0
  def _refresh_cached_state(self):
      """
      Poll the device status and merge the result into the cached state.

      After a non-empty response, non-persistent dps missing from it are
      dropped from the cache and every child entity is asked to update.
      """
      new_state = self._api.status()
      if new_state:
          polled_dps = new_state.get("dps", {})
          self._cached_state = {**self._cached_state, **polled_dps}
          self._cached_state["updated_at"] = time()
          for child in self._children:
              for dp in child._config.dps():
                  # Clear non-persistent dps that were not in the poll
                  if dp.persist or dp.id in polled_dps:
                      continue
                  self._cached_state.pop(dp.id, None)
              child.async_schedule_update_ha_state()

      _LOGGER.debug(
          "%s refreshed device state: %s",
          self.name,
          log_json(new_state),
      )
      _LOGGER.debug(
          "new state (incl pending): %s",
          log_json(self._get_cached_state()),
      )
  357. async def async_set_properties(self, properties):
  358. if len(properties) == 0:
  359. return
  360. self._add_properties_to_pending_updates(properties)
  361. await self._debounce_sending_updates()
  def _add_properties_to_pending_updates(self, properties):
      """Record each property as an unsent pending update stamped with now."""
      stamp = time()
      pending = self._get_pending_updates()
      pending.update(
          {
              dps_id: {"value": new_value, "updated_at": stamp, "sent": False}
              for dps_id, new_value in properties.items()
          }
      )
      _LOGGER.debug(
          "%s new pending updates: %s",
          self.name,
          log_json(pending),
      )
  376. async def _debounce_sending_updates(self):
  377. now = time()
  378. since = now - self._last_connection
  379. # set this now to avoid a race condition, it will be updated later
  380. # when the data is actally sent
  381. self._last_connection = now
  382. # Only delay a second if there was recently another command.
  383. # Otherwise delay 1ms, to keep things simple by reusing the
  384. # same send mechanism.
  385. waittime = 1 if since < 1.1 else 0.001
  386. await asyncio.sleep(waittime)
  387. await self._send_pending_updates()
  async def _send_pending_updates(self):
      """Send all not-yet-sent pending updates, retrying on connection failure."""
      unsent = self._get_unsent_properties()
      _LOGGER.debug(
          "%s sending dps update: %s",
          self.name,
          log_json(unsent),
      )

      def push_to_device():
          return self._set_values(unsent)

      await self._retry_on_failed_connection(
          push_to_device,
          "Failed to update device state.",
      )
  399. def _set_values(self, properties):
  400. try:
  401. self._lock.acquire()
  402. self._api.set_multiple_values(properties, nowait=True)
  403. self._cached_state["updated_at"] = 0
  404. now = time()
  405. self._last_connection = now
  406. pending_updates = self._get_pending_updates()
  407. for key in properties.keys():
  408. pending_updates[key]["updated_at"] = now
  409. pending_updates[key]["sent"] = True
  410. finally:
  411. self._lock.release()
  412. async def _retry_on_failed_connection(self, func, error_message):
  413. if self._api_protocol_version_index is None:
  414. await self._rotate_api_protocol_version()
  415. auto = (self._protocol_configured == "auto") and (
  416. not self._api_protocol_working
  417. )
  418. connections = (
  419. self._AUTO_CONNECTION_ATTEMPTS
  420. if auto
  421. else self._SINGLE_PROTO_CONNECTION_ATTEMPTS
  422. )
  423. for i in range(connections):
  424. try:
  425. if not self._hass.is_stopping:
  426. retval = await self._hass.async_add_executor_job(func)
  427. if type(retval) is dict and "Error" in retval:
  428. raise AttributeError(retval["Error"])
  429. self._api_protocol_working = True
  430. return retval
  431. except Exception as e:
  432. _LOGGER.debug(
  433. "Retrying after exception %s %s (%d/%d)",
  434. type(e),
  435. e,
  436. i,
  437. connections,
  438. )
  439. if i + 1 == connections:
  440. self._reset_cached_state()
  441. self._api_protocol_working = False
  442. for entity in self._children:
  443. entity.async_schedule_update_ha_state()
  444. _LOGGER.error(error_message)
  445. if not self._api_protocol_working:
  446. await self._rotate_api_protocol_version()
  447. def _get_cached_state(self):
  448. cached_state = self._cached_state.copy()
  449. return {**cached_state, **self._get_pending_properties()}
  450. def _get_pending_properties(self):
  451. return {
  452. key: property["value"]
  453. for key, property in self._get_pending_updates().items()
  454. }
  455. def _get_unsent_properties(self):
  456. return {
  457. key: info["value"]
  458. for key, info in self._get_pending_updates().items()
  459. if not info["sent"]
  460. }
  461. def _get_pending_updates(self):
  462. now = time()
  463. self._pending_updates = {
  464. key: value
  465. for key, value in self._pending_updates.items()
  466. if now - value.get("updated_at", 0) < self._FAKE_IT_TIMEOUT
  467. }
  468. return self._pending_updates
  async def _rotate_api_protocol_version(self):
      """
      Select the protocol version to use and apply it to the API object.

      On first use the index of the configured version is looked up,
      falling back to the first known version when it is not in the list.
      On later calls the index only advances, wrapping at the end of the
      list, when the configured version is "auto".
      """
      index = self._api_protocol_version_index
      if index is None:
          try:
              index = API_PROTOCOL_VERSIONS.index(self._protocol_configured)
          except ValueError:
              index = 0
      elif self._protocol_configured == "auto":
          # only rotate if configured as auto
          index += 1

      if index >= len(API_PROTOCOL_VERSIONS):
          index = 0
      self._api_protocol_version_index = index

      version = API_PROTOCOL_VERSIONS[index]
      _LOGGER.info(
          "Setting protocol version for %s to %0.1f",
          self.name,
          version,
      )
      await self._hass.async_add_executor_job(self._api.set_version, version)
  492. @staticmethod
  493. def get_key_for_value(obj, value, fallback=None):
  494. keys = list(obj.keys())
  495. values = list(obj.values())
  496. return keys[values.index(value)] if value in values else fallback
  def setup_device(hass: HomeAssistant, config: dict):
      """
      Setup a tuya device based on passed in config.

      The device is stored in hass.data under the integration domain,
      keyed by its device id, and is also returned.
      """
      device_id = config[CONF_DEVICE_ID]
      _LOGGER.info("Creating device: %s", device_id)

      domain_data = hass.data.setdefault(DOMAIN, {})
      device = TuyaLocalDevice(
          name=config[CONF_NAME],
          dev_id=device_id,
          address=config[CONF_HOST],
          local_key=config[CONF_LOCAL_KEY],
          protocol_version=config[CONF_PROTOCOL_VERSION],
          hass=hass,
          poll_only=config[CONF_POLL_ONLY],
      )
      domain_data[device_id] = {"device": device}
      return device
  async def async_delete_device(hass: HomeAssistant, config: dict):
      """Stop the configured device and remove it from hass.data."""
      device_id = config[CONF_DEVICE_ID]
      _LOGGER.info("Deleting device: %s", device_id)
      device = hass.data[DOMAIN][device_id]["device"]
      await device.async_stop()
      # Look the entry up again: the stop above awaited, so do not rely on
      # a reference taken before it.
      del hass.data[DOMAIN][device_id]["device"]