Refactor ESPHome manager to avoid sending signals in tests (#116033)

This commit is contained in:
J. Nick Koston 2024-04-24 16:14:44 +02:00 committed by GitHub
parent 220dc1f125
commit d0f5e40b19
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 48 additions and 46 deletions

View file

@ -49,9 +49,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.util.signal_type import SignalType
from .const import DOMAIN from .const import DOMAIN
from .dashboard import async_get_dashboard from .dashboard import async_get_dashboard
@ -126,6 +124,9 @@ class RuntimeEntryData:
default_factory=dict default_factory=dict
) )
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set) device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
static_info_update_subscriptions: set[Callable[[list[EntityInfo]], None]] = field(
default_factory=set
)
loaded_platforms: set[Platform] = field(default_factory=set) loaded_platforms: set[Platform] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: StoreData | None = None _storage_contents: StoreData | None = None
@ -154,11 +155,6 @@ class RuntimeEntryData:
"_", " " "_", " "
) )
@property
def signal_static_info_updated(self) -> SignalType[list[EntityInfo]]:
"""Return the signal to listen to for updates on static info."""
return SignalType(f"esphome_{self.entry_id}_on_list")
@callback @callback
def async_register_static_info_callback( def async_register_static_info_callback(
self, self,
@ -303,8 +299,9 @@ class RuntimeEntryData:
for callback_ in callbacks_: for callback_ in callbacks_:
callback_(entity_infos) callback_(entity_infos)
# Then send dispatcher event # Finally update static info subscriptions
async_dispatcher_send(hass, self.signal_static_info_updated, infos) for callback_ in self.static_info_update_subscriptions:
callback_(infos)
@callback @callback
def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE: def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE:
@ -317,6 +314,21 @@ class RuntimeEntryData:
"""Unsubscribe to device updates.""" """Unsubscribe to device updates."""
self.device_update_subscriptions.remove(callback_) self.device_update_subscriptions.remove(callback_)
@callback
def async_subscribe_static_info_updated(
self, callback_: Callable[[list[EntityInfo]], None]
) -> CALLBACK_TYPE:
"""Subscribe to static info updates."""
self.static_info_update_subscriptions.add(callback_)
return partial(self._async_unsubscribe_static_info_updated, callback_)
@callback
def _async_unsubscribe_static_info_updated(
self, callback_: Callable[[list[EntityInfo]], None]
) -> None:
"""Unsubscribe to static info updates."""
self.static_info_update_subscriptions.remove(callback_)
@callback @callback
def async_subscribe_state_update( def async_subscribe_state_update(
self, self,

View file

@ -17,7 +17,6 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
@ -149,14 +148,9 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Handle entity added to Home Assistant.""" """Handle entity added to Home Assistant."""
await super().async_added_to_hass() await super().async_added_to_hass()
hass = self.hass
entry_data = self._entry_data entry_data = self._entry_data
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( entry_data.async_subscribe_static_info_updated(self._handle_device_update)
hass,
entry_data.signal_static_info_updated,
self._handle_device_update,
)
) )
self.async_on_remove( self.async_on_remove(
entry_data.async_subscribe_device_updated(self._handle_device_update) entry_data.async_subscribe_device_updated(self._handle_device_update)

View file

@ -181,7 +181,9 @@ async def mock_dashboard(hass):
class MockESPHomeDevice: class MockESPHomeDevice:
"""Mock an esphome device.""" """Mock an esphome device."""
def __init__(self, entry: MockConfigEntry, client: APIClient) -> None: def __init__(
self, entry: MockConfigEntry, client: APIClient, device_info: DeviceInfo
) -> None:
"""Init the mock.""" """Init the mock."""
self.entry = entry self.entry = entry
self.client = client self.client = client
@ -193,6 +195,7 @@ class MockESPHomeDevice:
self.home_assistant_state_subscription_callback: Callable[ self.home_assistant_state_subscription_callback: Callable[
[str, str | None], None [str, str | None], None
] ]
self.device_info = device_info
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None: def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
"""Set the state callback.""" """Set the state callback."""
@ -274,8 +277,6 @@ async def _mock_generic_device_entry(
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
mock_device = MockESPHomeDevice(entry, mock_client)
default_device_info = { default_device_info = {
"name": "test", "name": "test",
"friendly_name": "Test", "friendly_name": "Test",
@ -284,6 +285,8 @@ async def _mock_generic_device_entry(
} }
device_info = DeviceInfo(**(default_device_info | mock_device_info)) device_info = DeviceInfo(**(default_device_info | mock_device_info))
mock_device = MockESPHomeDevice(entry, mock_client, device_info)
def _subscribe_states(callback: Callable[[EntityState], None]) -> None: def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
"""Subscribe to state.""" """Subscribe to state."""
mock_device.set_state_callback(callback) mock_device.set_state_callback(callback)
@ -302,7 +305,7 @@ async def _mock_generic_device_entry(
"""Subscribe to home assistant states.""" """Subscribe to home assistant states."""
mock_device.set_home_assistant_state_subscription_callback(on_state_sub) mock_device.set_home_assistant_state_subscription_callback(on_state_sub)
mock_client.device_info = AsyncMock(return_value=device_info) mock_client.device_info = AsyncMock(return_value=mock_device.device_info)
mock_client.subscribe_voice_assistant = Mock() mock_client.subscribe_voice_assistant = Mock()
mock_client.list_entities_services = AsyncMock( mock_client.list_entities_services = AsyncMock(
return_value=mock_list_entities_services return_value=mock_list_entities_services

View file

@ -1,7 +1,6 @@
"""Test ESPHome update entities.""" """Test ESPHome update entities."""
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import dataclasses
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from aioesphomeapi import APIClient, EntityInfo, EntityState, UserService from aioesphomeapi import APIClient, EntityInfo, EntityState, UserService
@ -18,7 +17,6 @@ from homeassistant.const import (
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send
from .conftest import MockESPHomeDevice from .conftest import MockESPHomeDevice
@ -176,9 +174,11 @@ async def test_update_entity(
async def test_update_static_info( async def test_update_static_info(
hass: HomeAssistant, hass: HomeAssistant,
stub_reconnect, mock_client: APIClient,
mock_config_entry, mock_esphome_device: Callable[
mock_device_info, [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_dashboard, mock_dashboard,
) -> None: ) -> None:
"""Test ESPHome update entity.""" """Test ESPHome update entity."""
@ -190,32 +190,25 @@ async def test_update_static_info(
] ]
await async_get_dashboard(hass).async_refresh() await async_get_dashboard(hass).async_refresh()
signal_static_info_updated = f"esphome_{mock_config_entry.entry_id}_on_list" mock_device: MockESPHomeDevice = await mock_esphome_device(
runtime_data = Mock( mock_client=mock_client,
available=True, entity_info=[],
device_info=mock_device_info, user_service=[],
signal_static_info_updated=signal_static_info_updated, states=[],
) )
with patch( state = hass.states.get("update.test_firmware")
"homeassistant.components.esphome.update.DomainData.get_entry_data",
return_value=runtime_data,
):
assert await hass.config_entries.async_forward_entry_setup(
mock_config_entry, "update"
)
state = hass.states.get("update.none_firmware")
assert state is not None assert state is not None
assert state.state == "on" assert state.state == STATE_ON
runtime_data.device_info = dataclasses.replace( object.__setattr__(mock_device.device_info, "esphome_version", "1.2.3")
runtime_data.device_info, esphome_version="1.2.3" await mock_device.mock_disconnect(True)
) await mock_device.mock_connect()
async_dispatcher_send(hass, signal_static_info_updated, [])
state = hass.states.get("update.none_firmware") await hass.async_block_till_done(wait_background_tasks=True)
assert state.state == "off"
state = hass.states.get("update.test_firmware")
assert state.state == STATE_OFF
@pytest.mark.parametrize( @pytest.mark.parametrize(