Better teardown and setup of Roborock connections (#106092)

Co-authored-by: Robert Resch <robert@resch.dev>
This commit is contained in:
Luke Lashley 2024-02-12 03:37:37 -05:00 committed by GitHub
parent 2516eafba6
commit 33cdcce191
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 47 additions and 27 deletions

View file

@ -115,6 +115,7 @@ async def setup_device(
device.name,
)
_LOGGER.debug(err)
await mqtt_client.async_release()
raise err
coordinator = RoborockDataUpdateCoordinator(
hass, device, networking, product_info, mqtt_client
@ -130,6 +131,7 @@ async def setup_device(
try:
await coordinator.async_config_entry_first_refresh()
except ConfigEntryNotReady as ex:
await coordinator.release()
if isinstance(coordinator.api, RoborockMqttClient):
_LOGGER.warning(
"Not setting up %s because the we failed to get data for the first time using the online client. "
@ -158,14 +160,10 @@ async def setup_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Handle removal of an entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
await asyncio.gather(
*(
coordinator.release()
for coordinator in hass.data[DOMAIN][entry.entry_id].values()
)
)
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
release_tasks = set()
for coordinator in hass.data[DOMAIN][entry.entry_id].values():
release_tasks.add(coordinator.release())
hass.data[DOMAIN].pop(entry.entry_id)
await asyncio.gather(*release_tasks)
return unload_ok

View file

@ -79,7 +79,8 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
async def release(self) -> None:
"""Disconnect from API."""
await self.api.async_disconnect()
await self.api.async_release()
await self.cloud_api.async_release()
async def _update_device_prop(self) -> None:
"""Update device properties."""

View file

@ -1,5 +1,4 @@
"""Support for Roborock device base class."""
from typing import Any
from roborock.api import AttributeCache, RoborockClient
@ -7,6 +6,7 @@ from roborock.cloud_api import RoborockMqttClient
from roborock.command_cache import CacheableAttribute
from roborock.containers import Consumable, Status
from roborock.exceptions import RoborockException
from roborock.roborock_message import RoborockDataProtocol
from roborock.roborock_typing import RoborockCommand
from homeassistant.exceptions import HomeAssistantError
@ -24,7 +24,10 @@ class RoborockEntity(Entity):
_attr_has_entity_name = True
def __init__(
self, unique_id: str, device_info: DeviceInfo, api: RoborockClient
self,
unique_id: str,
device_info: DeviceInfo,
api: RoborockClient,
) -> None:
"""Initialize the coordinated Roborock Device."""
self._attr_unique_id = unique_id
@ -75,6 +78,9 @@ class RoborockCoordinatedEntity(
self,
unique_id: str,
coordinator: RoborockDataUpdateCoordinator,
listener_request: list[RoborockDataProtocol]
| RoborockDataProtocol
| None = None,
) -> None:
"""Initialize the coordinated Roborock Device."""
RoborockEntity.__init__(
@ -85,6 +91,23 @@ class RoborockCoordinatedEntity(
)
CoordinatorEntity.__init__(self, coordinator=coordinator)
self._attr_unique_id = unique_id
if isinstance(listener_request, RoborockDataProtocol):
listener_request = [listener_request]
self.listener_requests = listener_request or []
async def async_added_to_hass(self) -> None:
"""Add listeners when the device is added to hass."""
await super().async_added_to_hass()
for listener_request in self.listener_requests:
self.api.add_listener(
listener_request, self._update_from_listener, cache=self.api.cache
)
async def async_will_remove_from_hass(self) -> None:
"""Remove listeners when the device is removed from hass."""
for listener_request in self.listener_requests:
self.api.remove_listener(listener_request, self._update_from_listener)
await super().async_will_remove_from_hass()
@property
def _device_status(self) -> Status:
@ -107,7 +130,7 @@ class RoborockCoordinatedEntity(
await self.coordinator.async_refresh()
return res
def _update_from_listener(self, value: Status | Consumable):
def _update_from_listener(self, value: Status | Consumable) -> None:
"""Update the status or consumable data from a listener and then write the new entity state."""
if isinstance(value, Status):
self.coordinator.roborock_device_info.props.status = value

View file

@ -107,10 +107,8 @@ class RoborockSelectEntity(RoborockCoordinatedEntity, SelectEntity):
) -> None:
"""Create a select entity."""
self.entity_description = entity_description
super().__init__(unique_id, coordinator)
super().__init__(unique_id, coordinator, entity_description.protocol_listener)
self._attr_options = options
if (protocol := self.entity_description.protocol_listener) is not None:
self.api.add_listener(protocol, self._update_from_listener, self.api.cache)
async def async_select_option(self, option: str) -> None:
"""Set the option."""

View file

@ -232,10 +232,8 @@ class RoborockSensorEntity(RoborockCoordinatedEntity, SensorEntity):
description: RoborockSensorDescription,
) -> None:
"""Initialize the entity."""
super().__init__(unique_id, coordinator)
self.entity_description = description
if (protocol := self.entity_description.protocol_listener) is not None:
self.api.add_listener(protocol, self._update_from_listener, self.api.cache)
super().__init__(unique_id, coordinator, description.protocol_listener)
@property
def native_value(self) -> StateType | datetime.datetime:

View file

@ -92,14 +92,16 @@ class RoborockVacuum(RoborockCoordinatedEntity, StateVacuumEntity):
) -> None:
"""Initialize a vacuum."""
StateVacuumEntity.__init__(self)
RoborockCoordinatedEntity.__init__(self, unique_id, coordinator)
RoborockCoordinatedEntity.__init__(
self,
unique_id,
coordinator,
listener_request=[
RoborockDataProtocol.FAN_POWER,
RoborockDataProtocol.STATE,
],
)
self._attr_fan_speed_list = self._device_status.fan_power_options
self.api.add_listener(
RoborockDataProtocol.FAN_POWER, self._update_from_listener, self.api.cache
)
self.api.add_listener(
RoborockDataProtocol.STATE, self._update_from_listener, self.api.cache
)
@property
def state(self) -> str | None:

View file

@ -18,7 +18,7 @@ async def test_unload_entry(
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert setup_entry.state is ConfigEntryState.LOADED
with patch(
"homeassistant.components.roborock.coordinator.RoborockLocalClient.async_disconnect"
"homeassistant.components.roborock.coordinator.RoborockLocalClient.async_release"
) as mock_disconnect:
assert await hass.config_entries.async_unload(setup_entry.entry_id)
await hass.async_block_till_done()