Refactor ESPHome manager to avoid sending signals in tests (#116033)
This commit is contained in:
parent
220dc1f125
commit
d0f5e40b19
4 changed files with 48 additions and 46 deletions
|
@ -49,9 +49,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.util.signal_type import SignalType
|
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .dashboard import async_get_dashboard
|
from .dashboard import async_get_dashboard
|
||||||
|
@ -126,6 +124,9 @@ class RuntimeEntryData:
|
||||||
default_factory=dict
|
default_factory=dict
|
||||||
)
|
)
|
||||||
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
|
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
|
||||||
|
static_info_update_subscriptions: set[Callable[[list[EntityInfo]], None]] = field(
|
||||||
|
default_factory=set
|
||||||
|
)
|
||||||
loaded_platforms: set[Platform] = field(default_factory=set)
|
loaded_platforms: set[Platform] = field(default_factory=set)
|
||||||
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||||
_storage_contents: StoreData | None = None
|
_storage_contents: StoreData | None = None
|
||||||
|
@ -154,11 +155,6 @@ class RuntimeEntryData:
|
||||||
"_", " "
|
"_", " "
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def signal_static_info_updated(self) -> SignalType[list[EntityInfo]]:
|
|
||||||
"""Return the signal to listen to for updates on static info."""
|
|
||||||
return SignalType(f"esphome_{self.entry_id}_on_list")
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_static_info_callback(
|
def async_register_static_info_callback(
|
||||||
self,
|
self,
|
||||||
|
@ -303,8 +299,9 @@ class RuntimeEntryData:
|
||||||
for callback_ in callbacks_:
|
for callback_ in callbacks_:
|
||||||
callback_(entity_infos)
|
callback_(entity_infos)
|
||||||
|
|
||||||
# Then send dispatcher event
|
# Finally update static info subscriptions
|
||||||
async_dispatcher_send(hass, self.signal_static_info_updated, infos)
|
for callback_ in self.static_info_update_subscriptions:
|
||||||
|
callback_(infos)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE:
|
def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE:
|
||||||
|
@ -317,6 +314,21 @@ class RuntimeEntryData:
|
||||||
"""Unsubscribe to device updates."""
|
"""Unsubscribe to device updates."""
|
||||||
self.device_update_subscriptions.remove(callback_)
|
self.device_update_subscriptions.remove(callback_)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_subscribe_static_info_updated(
|
||||||
|
self, callback_: Callable[[list[EntityInfo]], None]
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Subscribe to static info updates."""
|
||||||
|
self.static_info_update_subscriptions.add(callback_)
|
||||||
|
return partial(self._async_unsubscribe_static_info_updated, callback_)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_unsubscribe_static_info_updated(
|
||||||
|
self, callback_: Callable[[list[EntityInfo]], None]
|
||||||
|
) -> None:
|
||||||
|
"""Unsubscribe to static info updates."""
|
||||||
|
self.static_info_update_subscriptions.remove(callback_)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_subscribe_state_update(
|
def async_subscribe_state_update(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -17,7 +17,6 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr
|
from homeassistant.helpers import device_registry as dr
|
||||||
from homeassistant.helpers.device_registry import DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceInfo
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
@ -149,14 +148,9 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity):
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Handle entity added to Home Assistant."""
|
"""Handle entity added to Home Assistant."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
hass = self.hass
|
|
||||||
entry_data = self._entry_data
|
entry_data = self._entry_data
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
async_dispatcher_connect(
|
entry_data.async_subscribe_static_info_updated(self._handle_device_update)
|
||||||
hass,
|
|
||||||
entry_data.signal_static_info_updated,
|
|
||||||
self._handle_device_update,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
entry_data.async_subscribe_device_updated(self._handle_device_update)
|
entry_data.async_subscribe_device_updated(self._handle_device_update)
|
||||||
|
|
|
@ -181,7 +181,9 @@ async def mock_dashboard(hass):
|
||||||
class MockESPHomeDevice:
|
class MockESPHomeDevice:
|
||||||
"""Mock an esphome device."""
|
"""Mock an esphome device."""
|
||||||
|
|
||||||
def __init__(self, entry: MockConfigEntry, client: APIClient) -> None:
|
def __init__(
|
||||||
|
self, entry: MockConfigEntry, client: APIClient, device_info: DeviceInfo
|
||||||
|
) -> None:
|
||||||
"""Init the mock."""
|
"""Init the mock."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.client = client
|
self.client = client
|
||||||
|
@ -193,6 +195,7 @@ class MockESPHomeDevice:
|
||||||
self.home_assistant_state_subscription_callback: Callable[
|
self.home_assistant_state_subscription_callback: Callable[
|
||||||
[str, str | None], None
|
[str, str | None], None
|
||||||
]
|
]
|
||||||
|
self.device_info = device_info
|
||||||
|
|
||||||
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
|
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
|
||||||
"""Set the state callback."""
|
"""Set the state callback."""
|
||||||
|
@ -274,8 +277,6 @@ async def _mock_generic_device_entry(
|
||||||
)
|
)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
mock_device = MockESPHomeDevice(entry, mock_client)
|
|
||||||
|
|
||||||
default_device_info = {
|
default_device_info = {
|
||||||
"name": "test",
|
"name": "test",
|
||||||
"friendly_name": "Test",
|
"friendly_name": "Test",
|
||||||
|
@ -284,6 +285,8 @@ async def _mock_generic_device_entry(
|
||||||
}
|
}
|
||||||
device_info = DeviceInfo(**(default_device_info | mock_device_info))
|
device_info = DeviceInfo(**(default_device_info | mock_device_info))
|
||||||
|
|
||||||
|
mock_device = MockESPHomeDevice(entry, mock_client, device_info)
|
||||||
|
|
||||||
def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
|
def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
|
||||||
"""Subscribe to state."""
|
"""Subscribe to state."""
|
||||||
mock_device.set_state_callback(callback)
|
mock_device.set_state_callback(callback)
|
||||||
|
@ -302,7 +305,7 @@ async def _mock_generic_device_entry(
|
||||||
"""Subscribe to home assistant states."""
|
"""Subscribe to home assistant states."""
|
||||||
mock_device.set_home_assistant_state_subscription_callback(on_state_sub)
|
mock_device.set_home_assistant_state_subscription_callback(on_state_sub)
|
||||||
|
|
||||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
mock_client.device_info = AsyncMock(return_value=mock_device.device_info)
|
||||||
mock_client.subscribe_voice_assistant = Mock()
|
mock_client.subscribe_voice_assistant = Mock()
|
||||||
mock_client.list_entities_services = AsyncMock(
|
mock_client.list_entities_services = AsyncMock(
|
||||||
return_value=mock_list_entities_services
|
return_value=mock_list_entities_services
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""Test ESPHome update entities."""
|
"""Test ESPHome update entities."""
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
import dataclasses
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from aioesphomeapi import APIClient, EntityInfo, EntityState, UserService
|
from aioesphomeapi import APIClient, EntityInfo, EntityState, UserService
|
||||||
|
@ -18,7 +17,6 @@ from homeassistant.const import (
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
||||||
|
|
||||||
from .conftest import MockESPHomeDevice
|
from .conftest import MockESPHomeDevice
|
||||||
|
|
||||||
|
@ -176,9 +174,11 @@ async def test_update_entity(
|
||||||
|
|
||||||
async def test_update_static_info(
|
async def test_update_static_info(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
stub_reconnect,
|
mock_client: APIClient,
|
||||||
mock_config_entry,
|
mock_esphome_device: Callable[
|
||||||
mock_device_info,
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
mock_dashboard,
|
mock_dashboard,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test ESPHome update entity."""
|
"""Test ESPHome update entity."""
|
||||||
|
@ -190,32 +190,25 @@ async def test_update_static_info(
|
||||||
]
|
]
|
||||||
await async_get_dashboard(hass).async_refresh()
|
await async_get_dashboard(hass).async_refresh()
|
||||||
|
|
||||||
signal_static_info_updated = f"esphome_{mock_config_entry.entry_id}_on_list"
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
runtime_data = Mock(
|
mock_client=mock_client,
|
||||||
available=True,
|
entity_info=[],
|
||||||
device_info=mock_device_info,
|
user_service=[],
|
||||||
signal_static_info_updated=signal_static_info_updated,
|
states=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
state = hass.states.get("update.test_firmware")
|
||||||
"homeassistant.components.esphome.update.DomainData.get_entry_data",
|
|
||||||
return_value=runtime_data,
|
|
||||||
):
|
|
||||||
assert await hass.config_entries.async_forward_entry_setup(
|
|
||||||
mock_config_entry, "update"
|
|
||||||
)
|
|
||||||
|
|
||||||
state = hass.states.get("update.none_firmware")
|
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state == "on"
|
assert state.state == STATE_ON
|
||||||
|
|
||||||
runtime_data.device_info = dataclasses.replace(
|
object.__setattr__(mock_device.device_info, "esphome_version", "1.2.3")
|
||||||
runtime_data.device_info, esphome_version="1.2.3"
|
await mock_device.mock_disconnect(True)
|
||||||
)
|
await mock_device.mock_connect()
|
||||||
async_dispatcher_send(hass, signal_static_info_updated, [])
|
|
||||||
|
|
||||||
state = hass.states.get("update.none_firmware")
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert state.state == "off"
|
|
||||||
|
state = hass.states.get("update.test_firmware")
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Add table
Reference in a new issue