Refactor handling of device updates in ESPHome (#112864)

This commit is contained in:
J. Nick Koston 2024-03-09 20:30:17 -10:00 committed by GitHub
parent 57ce0f77ed
commit f1b5dcdd1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 58 additions and 80 deletions

View file

@ -22,7 +22,6 @@ from homeassistant.helpers import entity_platform
import homeassistant.helpers.config_validation as cv
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -205,25 +204,19 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
entry_data = self._entry_data
hass = self.hass
key = self._key
static_info = self._static_info
self.async_on_remove(
async_dispatcher_connect(
hass,
entry_data.signal_device_updated,
entry_data.async_subscribe_device_updated(
self._on_device_update,
)
)
self.async_on_remove(
entry_data.async_subscribe_state_update(
self._state_type, key, self._on_state_update
self._state_type, self._key, self._on_state_update
)
)
self.async_on_remove(
entry_data.async_register_key_static_info_updated_callback(
static_info, self._on_static_info_update
self._static_info, self._on_static_info_update
)
)
self._update_state_from_entry_data()

View file

@ -108,18 +108,17 @@ class RuntimeEntryData:
device_info: DeviceInfo | None = None
bluetooth_device: ESPHomeBluetoothDevice | None = None
api_version: APIVersion = field(default_factory=APIVersion)
cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list)
disconnect_callbacks: set[Callable[[], None]] = field(default_factory=set)
state_subscriptions: dict[
tuple[type[EntityState], int], Callable[[], None]
] = field(default_factory=dict)
cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set)
state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field(
default_factory=dict
)
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
loaded_platforms: set[Platform] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: StoreData | None = None
_pending_storage: Callable[[], StoreData] | None = None
assist_pipeline_update_callbacks: list[Callable[[], None]] = field(
default_factory=list
)
assist_pipeline_update_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
assist_pipeline_state: bool = False
entity_info_callbacks: dict[
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
@ -143,11 +142,6 @@ class RuntimeEntryData:
"_", " "
)
@property
def signal_device_updated(self) -> str:
"""Return the signal to listen to for core device state update."""
return f"esphome_{self.entry_id}_on_device_update"
@property
def signal_static_info_updated(self) -> str:
"""Return the signal to listen to for updates on static info."""
@ -216,15 +210,15 @@ class RuntimeEntryData:
@callback
def async_subscribe_assist_pipeline_update(
self, update_callback: Callable[[], None]
) -> Callable[[], None]:
self, update_callback: CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Subscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.append(update_callback)
return partial(self._async_unsubscribe_assist_pipeline_update, update_callback)
@callback
def _async_unsubscribe_assist_pipeline_update(
self, update_callback: Callable[[], None]
self, update_callback: CALLBACK_TYPE
) -> None:
"""Unsubscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.remove(update_callback)
@ -307,13 +301,24 @@ class RuntimeEntryData:
# Then send dispatcher event
async_dispatcher_send(hass, self.signal_static_info_updated, infos)
@callback
def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE:
"""Subscribe to state updates."""
self.device_update_subscriptions.add(callback_)
return partial(self._async_unsubscribe_device_update, callback_)
@callback
def _async_unsubscribe_device_update(self, callback_: CALLBACK_TYPE) -> None:
"""Unsubscribe to device updates."""
self.device_update_subscriptions.remove(callback_)
@callback
def async_subscribe_state_update(
self,
state_type: type[EntityState],
state_key: int,
entity_callback: Callable[[], None],
) -> Callable[[], None]:
entity_callback: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Subscribe to state updates."""
subscription_key = (state_type, state_key)
self.state_subscriptions[subscription_key] = entity_callback
@ -359,9 +364,10 @@ class RuntimeEntryData:
_LOGGER.exception("Error while calling subscription: %s", ex)
@callback
def async_update_device_state(self, hass: HomeAssistant) -> None:
def async_update_device_state(self) -> None:
"""Distribute an update of a core device state like availability."""
async_dispatcher_send(hass, self.signal_device_updated)
for callback_ in self.device_update_subscriptions.copy():
callback_()
async def async_load_from_store(self) -> tuple[list[EntityInfo], list[UserService]]:
"""Load the retained data from store and return de-serialized data."""

View file

@ -455,7 +455,7 @@ class ESPHomeManager:
self.device_id = _async_setup_device_registry(hass, entry, entry_data)
entry_data.async_update_device_state(hass)
entry_data.async_update_device_state()
await entry_data.async_update_static_infos(
hass, entry, entity_infos, device_info.mac_address
)
@ -510,7 +510,7 @@ class ESPHomeManager:
# since it generates a lot of state changed events and database
# writes when we already know we're shutting down and the state
# will be cleared anyway.
entry_data.async_update_device_state(hass)
entry_data.async_update_device_state()
async def on_connect_error(self, err: Exception) -> None:
"""Start reauth flow if appropriate connect error type."""

View file

@ -61,9 +61,7 @@ async def async_setup_entry(
return
unsubs = [
async_dispatcher_connect(
hass, entry_data.signal_device_updated, _async_setup_update_entity
),
entry_data.async_subscribe_device_updated(_async_setup_update_entity),
dashboard.async_add_listener(_async_setup_update_entity),
]
@ -159,11 +157,7 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity):
)
)
self.async_on_remove(
async_dispatcher_connect(
hass,
entry_data.signal_device_updated,
self._handle_device_update,
)
entry_data.async_subscribe_device_updated(self._handle_device_update)
)
async def async_install(

View file

@ -208,15 +208,25 @@ async def test_update_static_info(
@pytest.mark.parametrize(
"expected_disconnect_state", [(True, STATE_ON), (False, STATE_UNAVAILABLE)]
("expected_disconnect", "expected_state", "has_deep_sleep"),
[
(True, STATE_ON, False),
(False, STATE_UNAVAILABLE, False),
(True, STATE_ON, True),
(False, STATE_ON, True),
],
)
async def test_update_device_state_for_availability(
hass: HomeAssistant,
stub_reconnect,
expected_disconnect_state: tuple[bool, str],
mock_config_entry,
mock_device_info,
expected_disconnect: bool,
expected_state: str,
has_deep_sleep: bool,
mock_dashboard,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test ESPHome update entity changes availability with the device."""
mock_dashboard["configured"] = [
@ -226,46 +236,21 @@ async def test_update_device_state_for_availability(
},
]
await async_get_dashboard(hass).async_refresh()
signal_device_updated = f"esphome_{mock_config_entry.entry_id}_on_device_update"
runtime_data = Mock(
available=True,
expected_disconnect=False,
device_info=mock_device_info,
signal_device_updated=signal_device_updated,
mock_device = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={"has_deep_sleep": has_deep_sleep},
)
with patch(
"homeassistant.components.esphome.update.DomainData.get_entry_data",
return_value=runtime_data,
):
assert await hass.config_entries.async_forward_entry_setup(
mock_config_entry, "update"
)
state = hass.states.get("update.none_firmware")
state = hass.states.get("update.test_firmware")
assert state is not None
assert state.state == "on"
expected_disconnect, expected_state = expected_disconnect_state
runtime_data.available = False
runtime_data.expected_disconnect = expected_disconnect
async_dispatcher_send(hass, signal_device_updated)
state = hass.states.get("update.none_firmware")
assert state.state == STATE_ON
await mock_device.mock_disconnect(expected_disconnect)
state = hass.states.get("update.test_firmware")
assert state.state == expected_state
# Deep sleep devices should still be available
runtime_data.device_info = dataclasses.replace(
runtime_data.device_info, has_deep_sleep=True
)
async_dispatcher_send(hass, signal_device_updated)
state = hass.states.get("update.none_firmware")
assert state.state == "on"
async def test_update_entity_dashboard_not_available_startup(
hass: HomeAssistant,