Fix race and add test coverage for esphome select platform (#95019)

This commit is contained in:
J. Nick Koston 2023-06-22 01:19:47 +02:00 committed by GitHub
parent ef2669afe4
commit 65a5244d5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 47 deletions

View file

@ -319,7 +319,6 @@ omit =
homeassistant/components/esphome/lock.py homeassistant/components/esphome/lock.py
homeassistant/components/esphome/media_player.py homeassistant/components/esphome/media_player.py
homeassistant/components/esphome/number.py homeassistant/components/esphome/number.py
homeassistant/components/esphome/select.py
homeassistant/components/esphome/sensor.py homeassistant/components/esphome/sensor.py
homeassistant/components/esphome/switch.py homeassistant/components/esphome/switch.py
homeassistant/components/etherscan/sensor.py homeassistant/components/etherscan/sensor.py

View file

@ -845,6 +845,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
self._on_static_info_update, self._on_static_info_update,
) )
) )
self._update_state_from_entry_data()
@callback @callback
def _on_static_info_update(self, static_info: EntityInfo) -> None: def _on_static_info_update(self, static_info: EntityInfo) -> None:
@ -868,11 +869,9 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
self._attr_icon = None self._attr_icon = None
@callback @callback
def _on_state_update(self) -> None: def _update_state_from_entry_data(self) -> None:
"""Call when state changed. """Update state from entry data."""
Behavior can be changed in child classes
"""
state = self._entry_data.state state = self._entry_data.state
key = self._key key = self._key
state_type = self._state_type state_type = self._state_type
@ -880,6 +879,14 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
if has_state: if has_state:
self._state = cast(_StateT, state[state_type][key]) self._state = cast(_StateT, state[state_type][key])
self._has_state = has_state self._has_state = has_state
@callback
def _on_state_update(self) -> None:
"""Call when state changed.
Behavior can be changed in child classes
"""
self._update_state_from_entry_data()
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback

View file

@ -53,9 +53,8 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity):
@esphome_state_property @esphome_state_property
def current_option(self) -> str | None: def current_option(self) -> str | None:
"""Return the state of the entity.""" """Return the state of the entity."""
if self._state.missing_state: state = self._state
return None return None if state.missing_state else state.state
return self._state.state
async def async_select_option(self, option: str) -> None: async def async_select_option(self, option: str) -> None:
"""Change the selected option.""" """Change the selected option."""

View file

@ -2,9 +2,19 @@
from __future__ import annotations from __future__ import annotations
from asyncio import Event from asyncio import Event
from collections.abc import Callable
from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic from aioesphomeapi import (
APIClient,
APIVersion,
DeviceInfo,
EntityInfo,
EntityState,
ReconnectLogic,
UserService,
)
import pytest import pytest
from zeroconf import Zeroconf from zeroconf import Zeroconf
@ -82,7 +92,7 @@ async def init_integration(
@pytest.fixture @pytest.fixture
def mock_client(mock_device_info): def mock_client(mock_device_info) -> APIClient:
"""Mock APIClient.""" """Mock APIClient."""
mock_client = Mock(spec=APIClient) mock_client = Mock(spec=APIClient)
@ -132,14 +142,13 @@ async def mock_dashboard(hass):
yield data yield data
@pytest.fixture async def _mock_generic_device_entry(
async def mock_voice_assistant_entry(
hass: HomeAssistant, hass: HomeAssistant,
mock_client, mock_client: APIClient,
mock_device_info: dict[str, Any],
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
states: list[EntityState],
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant."""
async def _mock_voice_assistant_entry(version: int):
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={ data={
@ -153,13 +162,22 @@ async def mock_voice_assistant_entry(
device_info = DeviceInfo( device_info = DeviceInfo(
name="test", name="test",
friendly_name="Test", friendly_name="Test",
voice_assistant_version=version,
mac_address="11:22:33:44:55:aa", mac_address="11:22:33:44:55:aa",
esphome_version="1.0.0", esphome_version="1.0.0",
**mock_device_info,
) )
async def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
"""Subscribe to state."""
for state in states:
callback(state)
mock_client.device_info = AsyncMock(return_value=device_info) mock_client.device_info = AsyncMock(return_value=device_info)
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
mock_client.list_entities_services = AsyncMock(
return_value=mock_list_entities_services
)
mock_client.subscribe_states = _subscribe_states
try_connect_done = Event() try_connect_done = Event()
real_try_connect = ReconnectLogic._try_connect real_try_connect = ReconnectLogic._try_connect
@ -174,8 +192,23 @@ async def mock_voice_assistant_entry(
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await try_connect_done.wait() await try_connect_done.wait()
await hass.async_block_till_done()
return entry return entry
@pytest.fixture
async def mock_voice_assistant_entry(
hass: HomeAssistant,
mock_client: APIClient,
):
"""Set up an ESPHome entry with voice assistant."""
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
return await _mock_generic_device_entry(
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
)
return _mock_voice_assistant_entry return _mock_voice_assistant_entry
@ -189,3 +222,22 @@ async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfi
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry: async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant.""" """Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=2) return await mock_voice_assistant_entry(version=2)
@pytest.fixture
async def mock_generic_device_entry(
hass: HomeAssistant,
) -> MockConfigEntry:
"""Set up an ESPHome entry."""
async def _mock_device_entry(
mock_client: APIClient,
entity_info: list[EntityInfo],
user_service: list[UserService],
states: list[EntityState],
) -> MockConfigEntry:
return await _mock_generic_device_entry(
hass, mock_client, {}, (entity_info, user_service), states
)
return _mock_device_entry

View file

@ -1,6 +1,16 @@
"""Test ESPHome selects.""" """Test ESPHome selects."""
from unittest.mock import call
from aioesphomeapi import APIClient, SelectInfo, SelectState
from homeassistant.components.select import (
ATTR_OPTION,
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -13,3 +23,37 @@ async def test_pipeline_selector(
state = hass.states.get("select.test_assist_pipeline") state = hass.states.get("select.test_assist_pipeline")
assert state is not None assert state is not None
assert state.state == "preferred" assert state.state == "preferred"
async def test_select_generic_entity(
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
) -> None:
"""Test a generic select entity."""
entity_info = [
SelectInfo(
object_id="myselect",
key=1,
name="my select",
unique_id="my_select",
options=["a", "b"],
)
]
states = [SelectState(key=1, state="a")]
user_service = []
await mock_generic_device_entry(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
state = hass.states.get("select.test_my_select")
assert state is not None
assert state.state == "a"
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: "select.test_my_select", ATTR_OPTION: "b"},
blocking=True,
)
mock_client.select_command.assert_has_calls([call(1, "b")])