Better teardown and setup of Roborock connections (#106092)
Co-authored-by: Robert Resch <robert@resch.dev>
This commit is contained in:
parent
2516eafba6
commit
33cdcce191
7 changed files with 47 additions and 27 deletions
|
@ -115,6 +115,7 @@ async def setup_device(
|
|||
device.name,
|
||||
)
|
||||
_LOGGER.debug(err)
|
||||
await mqtt_client.async_release()
|
||||
raise err
|
||||
coordinator = RoborockDataUpdateCoordinator(
|
||||
hass, device, networking, product_info, mqtt_client
|
||||
|
@ -130,6 +131,7 @@ async def setup_device(
|
|||
try:
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
except ConfigEntryNotReady as ex:
|
||||
await coordinator.release()
|
||||
if isinstance(coordinator.api, RoborockMqttClient):
|
||||
_LOGGER.warning(
|
||||
"Not setting up %s because the we failed to get data for the first time using the online client. "
|
||||
|
@ -158,14 +160,10 @@ async def setup_device(
|
|||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Handle removal of an entry."""
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
if unload_ok:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
coordinator.release()
|
||||
for coordinator in hass.data[DOMAIN][entry.entry_id].values()
|
||||
)
|
||||
)
|
||||
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
release_tasks = set()
|
||||
for coordinator in hass.data[DOMAIN][entry.entry_id].values():
|
||||
release_tasks.add(coordinator.release())
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
await asyncio.gather(*release_tasks)
|
||||
return unload_ok
|
||||
|
|
|
@ -79,7 +79,8 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
|
|||
|
||||
async def release(self) -> None:
|
||||
"""Disconnect from API."""
|
||||
await self.api.async_disconnect()
|
||||
await self.api.async_release()
|
||||
await self.cloud_api.async_release()
|
||||
|
||||
async def _update_device_prop(self) -> None:
|
||||
"""Update device properties."""
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Support for Roborock device base class."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from roborock.api import AttributeCache, RoborockClient
|
||||
|
@ -7,6 +6,7 @@ from roborock.cloud_api import RoborockMqttClient
|
|||
from roborock.command_cache import CacheableAttribute
|
||||
from roborock.containers import Consumable, Status
|
||||
from roborock.exceptions import RoborockException
|
||||
from roborock.roborock_message import RoborockDataProtocol
|
||||
from roborock.roborock_typing import RoborockCommand
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -24,7 +24,10 @@ class RoborockEntity(Entity):
|
|||
_attr_has_entity_name = True
|
||||
|
||||
def __init__(
|
||||
self, unique_id: str, device_info: DeviceInfo, api: RoborockClient
|
||||
self,
|
||||
unique_id: str,
|
||||
device_info: DeviceInfo,
|
||||
api: RoborockClient,
|
||||
) -> None:
|
||||
"""Initialize the coordinated Roborock Device."""
|
||||
self._attr_unique_id = unique_id
|
||||
|
@ -75,6 +78,9 @@ class RoborockCoordinatedEntity(
|
|||
self,
|
||||
unique_id: str,
|
||||
coordinator: RoborockDataUpdateCoordinator,
|
||||
listener_request: list[RoborockDataProtocol]
|
||||
| RoborockDataProtocol
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""Initialize the coordinated Roborock Device."""
|
||||
RoborockEntity.__init__(
|
||||
|
@ -85,6 +91,23 @@ class RoborockCoordinatedEntity(
|
|||
)
|
||||
CoordinatorEntity.__init__(self, coordinator=coordinator)
|
||||
self._attr_unique_id = unique_id
|
||||
if isinstance(listener_request, RoborockDataProtocol):
|
||||
listener_request = [listener_request]
|
||||
self.listener_requests = listener_request or []
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Add listeners when the device is added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
for listener_request in self.listener_requests:
|
||||
self.api.add_listener(
|
||||
listener_request, self._update_from_listener, cache=self.api.cache
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Remove listeners when the device is removed from hass."""
|
||||
for listener_request in self.listener_requests:
|
||||
self.api.remove_listener(listener_request, self._update_from_listener)
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
@property
|
||||
def _device_status(self) -> Status:
|
||||
|
@ -107,7 +130,7 @@ class RoborockCoordinatedEntity(
|
|||
await self.coordinator.async_refresh()
|
||||
return res
|
||||
|
||||
def _update_from_listener(self, value: Status | Consumable):
|
||||
def _update_from_listener(self, value: Status | Consumable) -> None:
|
||||
"""Update the status or consumable data from a listener and then write the new entity state."""
|
||||
if isinstance(value, Status):
|
||||
self.coordinator.roborock_device_info.props.status = value
|
||||
|
|
|
@ -107,10 +107,8 @@ class RoborockSelectEntity(RoborockCoordinatedEntity, SelectEntity):
|
|||
) -> None:
|
||||
"""Create a select entity."""
|
||||
self.entity_description = entity_description
|
||||
super().__init__(unique_id, coordinator)
|
||||
super().__init__(unique_id, coordinator, entity_description.protocol_listener)
|
||||
self._attr_options = options
|
||||
if (protocol := self.entity_description.protocol_listener) is not None:
|
||||
self.api.add_listener(protocol, self._update_from_listener, self.api.cache)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Set the option."""
|
||||
|
|
|
@ -232,10 +232,8 @@ class RoborockSensorEntity(RoborockCoordinatedEntity, SensorEntity):
|
|||
description: RoborockSensorDescription,
|
||||
) -> None:
|
||||
"""Initialize the entity."""
|
||||
super().__init__(unique_id, coordinator)
|
||||
self.entity_description = description
|
||||
if (protocol := self.entity_description.protocol_listener) is not None:
|
||||
self.api.add_listener(protocol, self._update_from_listener, self.api.cache)
|
||||
super().__init__(unique_id, coordinator, description.protocol_listener)
|
||||
|
||||
@property
|
||||
def native_value(self) -> StateType | datetime.datetime:
|
||||
|
|
|
@ -92,14 +92,16 @@ class RoborockVacuum(RoborockCoordinatedEntity, StateVacuumEntity):
|
|||
) -> None:
|
||||
"""Initialize a vacuum."""
|
||||
StateVacuumEntity.__init__(self)
|
||||
RoborockCoordinatedEntity.__init__(self, unique_id, coordinator)
|
||||
RoborockCoordinatedEntity.__init__(
|
||||
self,
|
||||
unique_id,
|
||||
coordinator,
|
||||
listener_request=[
|
||||
RoborockDataProtocol.FAN_POWER,
|
||||
RoborockDataProtocol.STATE,
|
||||
],
|
||||
)
|
||||
self._attr_fan_speed_list = self._device_status.fan_power_options
|
||||
self.api.add_listener(
|
||||
RoborockDataProtocol.FAN_POWER, self._update_from_listener, self.api.cache
|
||||
)
|
||||
self.api.add_listener(
|
||||
RoborockDataProtocol.STATE, self._update_from_listener, self.api.cache
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self) -> str | None:
|
||||
|
|
|
@ -18,7 +18,7 @@ async def test_unload_entry(
|
|||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert setup_entry.state is ConfigEntryState.LOADED
|
||||
with patch(
|
||||
"homeassistant.components.roborock.coordinator.RoborockLocalClient.async_disconnect"
|
||||
"homeassistant.components.roborock.coordinator.RoborockLocalClient.async_release"
|
||||
) as mock_disconnect:
|
||||
assert await hass.config_entries.async_unload(setup_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
|
Loading…
Add table
Reference in a new issue