Refactor handling of device updates in ESPHome (#112864)
This commit is contained in:
parent
57ce0f77ed
commit
f1b5dcdd1b
5 changed files with 58 additions and 80 deletions
|
@ -22,7 +22,6 @@ from homeassistant.helpers import entity_platform
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
from homeassistant.helpers.device_registry import DeviceInfo
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
|
@ -205,25 +204,19 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
async def async_added_to_hass(self) -> None:
|
||||
"""Register callbacks."""
|
||||
entry_data = self._entry_data
|
||||
hass = self.hass
|
||||
key = self._key
|
||||
static_info = self._static_info
|
||||
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
hass,
|
||||
entry_data.signal_device_updated,
|
||||
entry_data.async_subscribe_device_updated(
|
||||
self._on_device_update,
|
||||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
entry_data.async_subscribe_state_update(
|
||||
self._state_type, key, self._on_state_update
|
||||
self._state_type, self._key, self._on_state_update
|
||||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
entry_data.async_register_key_static_info_updated_callback(
|
||||
static_info, self._on_static_info_update
|
||||
self._static_info, self._on_static_info_update
|
||||
)
|
||||
)
|
||||
self._update_state_from_entry_data()
|
||||
|
|
|
@ -108,18 +108,17 @@ class RuntimeEntryData:
|
|||
device_info: DeviceInfo | None = None
|
||||
bluetooth_device: ESPHomeBluetoothDevice | None = None
|
||||
api_version: APIVersion = field(default_factory=APIVersion)
|
||||
cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list)
|
||||
disconnect_callbacks: set[Callable[[], None]] = field(default_factory=set)
|
||||
state_subscriptions: dict[
|
||||
tuple[type[EntityState], int], Callable[[], None]
|
||||
] = field(default_factory=dict)
|
||||
cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||
disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set)
|
||||
state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
|
||||
loaded_platforms: set[Platform] = field(default_factory=set)
|
||||
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_storage_contents: StoreData | None = None
|
||||
_pending_storage: Callable[[], StoreData] | None = None
|
||||
assist_pipeline_update_callbacks: list[Callable[[], None]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
assist_pipeline_update_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||
assist_pipeline_state: bool = False
|
||||
entity_info_callbacks: dict[
|
||||
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
|
||||
|
@ -143,11 +142,6 @@ class RuntimeEntryData:
|
|||
"_", " "
|
||||
)
|
||||
|
||||
@property
|
||||
def signal_device_updated(self) -> str:
|
||||
"""Return the signal to listen to for core device state update."""
|
||||
return f"esphome_{self.entry_id}_on_device_update"
|
||||
|
||||
@property
|
||||
def signal_static_info_updated(self) -> str:
|
||||
"""Return the signal to listen to for updates on static info."""
|
||||
|
@ -216,15 +210,15 @@ class RuntimeEntryData:
|
|||
|
||||
@callback
|
||||
def async_subscribe_assist_pipeline_update(
|
||||
self, update_callback: Callable[[], None]
|
||||
) -> Callable[[], None]:
|
||||
self, update_callback: CALLBACK_TYPE
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Subscribe to assist pipeline updates."""
|
||||
self.assist_pipeline_update_callbacks.append(update_callback)
|
||||
return partial(self._async_unsubscribe_assist_pipeline_update, update_callback)
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe_assist_pipeline_update(
|
||||
self, update_callback: Callable[[], None]
|
||||
self, update_callback: CALLBACK_TYPE
|
||||
) -> None:
|
||||
"""Unsubscribe to assist pipeline updates."""
|
||||
self.assist_pipeline_update_callbacks.remove(update_callback)
|
||||
|
@ -307,13 +301,24 @@ class RuntimeEntryData:
|
|||
# Then send dispatcher event
|
||||
async_dispatcher_send(hass, self.signal_static_info_updated, infos)
|
||||
|
||||
@callback
|
||||
def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE:
|
||||
"""Subscribe to state updates."""
|
||||
self.device_update_subscriptions.add(callback_)
|
||||
return partial(self._async_unsubscribe_device_update, callback_)
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe_device_update(self, callback_: CALLBACK_TYPE) -> None:
|
||||
"""Unsubscribe to device updates."""
|
||||
self.device_update_subscriptions.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_subscribe_state_update(
|
||||
self,
|
||||
state_type: type[EntityState],
|
||||
state_key: int,
|
||||
entity_callback: Callable[[], None],
|
||||
) -> Callable[[], None]:
|
||||
entity_callback: CALLBACK_TYPE,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Subscribe to state updates."""
|
||||
subscription_key = (state_type, state_key)
|
||||
self.state_subscriptions[subscription_key] = entity_callback
|
||||
|
@ -359,9 +364,10 @@ class RuntimeEntryData:
|
|||
_LOGGER.exception("Error while calling subscription: %s", ex)
|
||||
|
||||
@callback
|
||||
def async_update_device_state(self, hass: HomeAssistant) -> None:
|
||||
def async_update_device_state(self) -> None:
|
||||
"""Distribute an update of a core device state like availability."""
|
||||
async_dispatcher_send(hass, self.signal_device_updated)
|
||||
for callback_ in self.device_update_subscriptions.copy():
|
||||
callback_()
|
||||
|
||||
async def async_load_from_store(self) -> tuple[list[EntityInfo], list[UserService]]:
|
||||
"""Load the retained data from store and return de-serialized data."""
|
||||
|
|
|
@ -455,7 +455,7 @@ class ESPHomeManager:
|
|||
|
||||
self.device_id = _async_setup_device_registry(hass, entry, entry_data)
|
||||
|
||||
entry_data.async_update_device_state(hass)
|
||||
entry_data.async_update_device_state()
|
||||
await entry_data.async_update_static_infos(
|
||||
hass, entry, entity_infos, device_info.mac_address
|
||||
)
|
||||
|
@ -510,7 +510,7 @@ class ESPHomeManager:
|
|||
# since it generates a lot of state changed events and database
|
||||
# writes when we already know we're shutting down and the state
|
||||
# will be cleared anyway.
|
||||
entry_data.async_update_device_state(hass)
|
||||
entry_data.async_update_device_state()
|
||||
|
||||
async def on_connect_error(self, err: Exception) -> None:
|
||||
"""Start reauth flow if appropriate connect error type."""
|
||||
|
|
|
@ -61,9 +61,7 @@ async def async_setup_entry(
|
|||
return
|
||||
|
||||
unsubs = [
|
||||
async_dispatcher_connect(
|
||||
hass, entry_data.signal_device_updated, _async_setup_update_entity
|
||||
),
|
||||
entry_data.async_subscribe_device_updated(_async_setup_update_entity),
|
||||
dashboard.async_add_listener(_async_setup_update_entity),
|
||||
]
|
||||
|
||||
|
@ -159,11 +157,7 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity):
|
|||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
hass,
|
||||
entry_data.signal_device_updated,
|
||||
self._handle_device_update,
|
||||
)
|
||||
entry_data.async_subscribe_device_updated(self._handle_device_update)
|
||||
)
|
||||
|
||||
async def async_install(
|
||||
|
|
|
@ -208,15 +208,25 @@ async def test_update_static_info(
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_disconnect_state", [(True, STATE_ON), (False, STATE_UNAVAILABLE)]
|
||||
("expected_disconnect", "expected_state", "has_deep_sleep"),
|
||||
[
|
||||
(True, STATE_ON, False),
|
||||
(False, STATE_UNAVAILABLE, False),
|
||||
(True, STATE_ON, True),
|
||||
(False, STATE_ON, True),
|
||||
],
|
||||
)
|
||||
async def test_update_device_state_for_availability(
|
||||
hass: HomeAssistant,
|
||||
stub_reconnect,
|
||||
expected_disconnect_state: tuple[bool, str],
|
||||
mock_config_entry,
|
||||
mock_device_info,
|
||||
expected_disconnect: bool,
|
||||
expected_state: str,
|
||||
has_deep_sleep: bool,
|
||||
mock_dashboard,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test ESPHome update entity changes availability with the device."""
|
||||
mock_dashboard["configured"] = [
|
||||
|
@ -226,46 +236,21 @@ async def test_update_device_state_for_availability(
|
|||
},
|
||||
]
|
||||
await async_get_dashboard(hass).async_refresh()
|
||||
|
||||
signal_device_updated = f"esphome_{mock_config_entry.entry_id}_on_device_update"
|
||||
runtime_data = Mock(
|
||||
available=True,
|
||||
expected_disconnect=False,
|
||||
device_info=mock_device_info,
|
||||
signal_device_updated=signal_device_updated,
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={"has_deep_sleep": has_deep_sleep},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.update.DomainData.get_entry_data",
|
||||
return_value=runtime_data,
|
||||
):
|
||||
assert await hass.config_entries.async_forward_entry_setup(
|
||||
mock_config_entry, "update"
|
||||
)
|
||||
|
||||
state = hass.states.get("update.none_firmware")
|
||||
state = hass.states.get("update.test_firmware")
|
||||
assert state is not None
|
||||
assert state.state == "on"
|
||||
|
||||
expected_disconnect, expected_state = expected_disconnect_state
|
||||
|
||||
runtime_data.available = False
|
||||
runtime_data.expected_disconnect = expected_disconnect
|
||||
async_dispatcher_send(hass, signal_device_updated)
|
||||
|
||||
state = hass.states.get("update.none_firmware")
|
||||
assert state.state == STATE_ON
|
||||
await mock_device.mock_disconnect(expected_disconnect)
|
||||
state = hass.states.get("update.test_firmware")
|
||||
assert state.state == expected_state
|
||||
|
||||
# Deep sleep devices should still be available
|
||||
runtime_data.device_info = dataclasses.replace(
|
||||
runtime_data.device_info, has_deep_sleep=True
|
||||
)
|
||||
|
||||
async_dispatcher_send(hass, signal_device_updated)
|
||||
|
||||
state = hass.states.get("update.none_firmware")
|
||||
assert state.state == "on"
|
||||
|
||||
|
||||
async def test_update_entity_dashboard_not_available_startup(
|
||||
hass: HomeAssistant,
|
||||
|
|
Loading…
Add table
Reference in a new issue