From d0f5e40b197c41e66a0d9b457bb5714d11c02ced Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 Apr 2024 16:14:44 +0200 Subject: [PATCH] Refactor ESPHome manager to avoid sending signals in tests (#116033) --- .../components/esphome/entry_data.py | 30 +++++++++---- homeassistant/components/esphome/update.py | 8 +--- tests/components/esphome/conftest.py | 11 +++-- tests/components/esphome/test_update.py | 45 ++++++++----------- 4 files changed, 48 insertions(+), 46 deletions(-) diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 7316c09cc5e..41b18c9b88c 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -49,9 +49,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import entity_registry as er -from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.storage import Store -from homeassistant.util.signal_type import SignalType from .const import DOMAIN from .dashboard import async_get_dashboard @@ -126,6 +124,9 @@ class RuntimeEntryData: default_factory=dict ) device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set) + static_info_update_subscriptions: set[Callable[[list[EntityInfo]], None]] = field( + default_factory=set + ) loaded_platforms: set[Platform] = field(default_factory=set) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _storage_contents: StoreData | None = None @@ -154,11 +155,6 @@ class RuntimeEntryData: "_", " " ) - @property - def signal_static_info_updated(self) -> SignalType[list[EntityInfo]]: - """Return the signal to listen to for updates on static info.""" - return SignalType(f"esphome_{self.entry_id}_on_list") - @callback def async_register_static_info_callback( self, @@ -303,8 +299,9 @@ class RuntimeEntryData: for callback_ in callbacks_: callback_(entity_infos) - # Then send dispatcher event - async_dispatcher_send(hass, self.signal_static_info_updated, infos) + # Finally update static info subscriptions + for callback_ in self.static_info_update_subscriptions: + callback_(infos) @callback def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE: @@ -317,6 +314,21 @@ class RuntimeEntryData: """Unsubscribe to device updates.""" self.device_update_subscriptions.remove(callback_) + @callback + def async_subscribe_static_info_updated( + self, callback_: Callable[[list[EntityInfo]], None] + ) -> CALLBACK_TYPE: + """Subscribe to static info updates.""" + self.static_info_update_subscriptions.add(callback_) + return partial(self._async_unsubscribe_static_info_updated, callback_) + + @callback + def _async_unsubscribe_static_info_updated( + self, callback_: Callable[[list[EntityInfo]], None] + ) -> None: + """Unsubscribe to static info updates.""" + self.static_info_update_subscriptions.remove(callback_) + @callback def async_subscribe_state_update( self, diff --git a/homeassistant/components/esphome/update.py b/homeassistant/components/esphome/update.py index 3e5a82bbd0b..b16a6e798b7 100644 --- a/homeassistant/components/esphome/update.py +++ b/homeassistant/components/esphome/update.py @@ -17,7 +17,6 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr from homeassistant.helpers.device_registry import DeviceInfo -from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity @@ -149,14 +148,9 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity): async def async_added_to_hass(self) -> None: """Handle entity added to Home Assistant.""" await super().async_added_to_hass() - hass = self.hass entry_data = self._entry_data self.async_on_remove( - async_dispatcher_connect( - hass, - entry_data.signal_static_info_updated, - self._handle_device_update, - ) + entry_data.async_subscribe_static_info_updated(self._handle_device_update) ) self.async_on_remove( entry_data.async_subscribe_device_updated(self._handle_device_update) diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index e23f020991d..f71b4196be6 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -181,7 +181,9 @@ async def mock_dashboard(hass): class MockESPHomeDevice: """Mock an esphome device.""" - def __init__(self, entry: MockConfigEntry, client: APIClient) -> None: + def __init__( + self, entry: MockConfigEntry, client: APIClient, device_info: DeviceInfo + ) -> None: """Init the mock.""" self.entry = entry self.client = client @@ -193,6 +195,7 @@ class MockESPHomeDevice: self.home_assistant_state_subscription_callback: Callable[ [str, str | None], None ] + self.device_info = device_info def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None: """Set the state callback.""" @@ -274,8 +277,6 @@ async def _mock_generic_device_entry( ) entry.add_to_hass(hass) - mock_device = MockESPHomeDevice(entry, mock_client) - default_device_info = { "name": "test", "friendly_name": "Test", @@ -284,6 +285,8 @@ async def _mock_generic_device_entry( } device_info = DeviceInfo(**(default_device_info | mock_device_info)) + mock_device = MockESPHomeDevice(entry, mock_client, device_info) + def _subscribe_states(callback: Callable[[EntityState], None]) -> None: """Subscribe to state.""" mock_device.set_state_callback(callback) @@ -302,7 +305,7 @@ async def _mock_generic_device_entry( """Subscribe to home assistant states.""" mock_device.set_home_assistant_state_subscription_callback(on_state_sub) - mock_client.device_info = AsyncMock(return_value=device_info) + mock_client.device_info = AsyncMock(return_value=mock_device.device_info) mock_client.subscribe_voice_assistant = Mock() mock_client.list_entities_services = AsyncMock( return_value=mock_list_entities_services diff --git a/tests/components/esphome/test_update.py b/tests/components/esphome/test_update.py index 959ad12876d..b3deb2f33ee 100644 --- a/tests/components/esphome/test_update.py +++ b/tests/components/esphome/test_update.py @@ -1,7 +1,6 @@ """Test ESPHome update entities.""" from collections.abc import Awaitable, Callable -import dataclasses from unittest.mock import Mock, patch from aioesphomeapi import APIClient, EntityInfo, EntityState, UserService @@ -18,7 +17,6 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.dispatcher import async_dispatcher_send from .conftest import MockESPHomeDevice @@ -176,9 +174,11 @@ async def test_update_entity( async def test_update_static_info( hass: HomeAssistant, - stub_reconnect, - mock_config_entry, - mock_device_info, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], mock_dashboard, ) -> None: """Test ESPHome update entity.""" @@ -190,32 +190,25 @@ async def test_update_static_info( ] await async_get_dashboard(hass).async_refresh() - signal_static_info_updated = f"esphome_{mock_config_entry.entry_id}_on_list" - runtime_data = Mock( - available=True, - device_info=mock_device_info, - signal_static_info_updated=signal_static_info_updated, + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], ) - with patch( - "homeassistant.components.esphome.update.DomainData.get_entry_data", - return_value=runtime_data, - ): - assert await hass.config_entries.async_forward_entry_setup( - mock_config_entry, "update" - ) - - state = hass.states.get("update.none_firmware") + state = hass.states.get("update.test_firmware") assert state is not None - assert state.state == "on" + assert state.state == STATE_ON - runtime_data.device_info = dataclasses.replace( - runtime_data.device_info, esphome_version="1.2.3" - ) - async_dispatcher_send(hass, signal_static_info_updated, []) + object.__setattr__(mock_device.device_info, "esphome_version", "1.2.3") + await mock_device.mock_disconnect(True) + await mock_device.mock_connect() - state = hass.states.get("update.none_firmware") - assert state.state == "off" + await hass.async_block_till_done(wait_background_tasks=True) + + state = hass.states.get("update.test_firmware") + assert state.state == STATE_OFF @pytest.mark.parametrize(