Fix race and add test coverage for esphome select platform (#95019)

This commit is contained in:
J. Nick Koston 2023-06-22 01:19:47 +02:00 committed by GitHub
parent ef2669afe4
commit 65a5244d5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 47 deletions

View file

@ -319,7 +319,6 @@ omit =
homeassistant/components/esphome/lock.py
homeassistant/components/esphome/media_player.py
homeassistant/components/esphome/number.py
homeassistant/components/esphome/select.py
homeassistant/components/esphome/sensor.py
homeassistant/components/esphome/switch.py
homeassistant/components/etherscan/sensor.py

View file

@ -845,6 +845,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
self._on_static_info_update,
)
)
self._update_state_from_entry_data()
@callback
def _on_static_info_update(self, static_info: EntityInfo) -> None:
@ -868,11 +869,9 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
self._attr_icon = None
@callback
def _on_state_update(self) -> None:
"""Call when state changed.
def _update_state_from_entry_data(self) -> None:
"""Update state from entry data."""
Behavior can be changed in child classes
"""
state = self._entry_data.state
key = self._key
state_type = self._state_type
@ -880,6 +879,14 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
if has_state:
self._state = cast(_StateT, state[state_type][key])
self._has_state = has_state
@callback
def _on_state_update(self) -> None:
"""Call when state changed.
Behavior can be changed in child classes
"""
self._update_state_from_entry_data()
self.async_write_ha_state()
@callback

View file

@ -53,9 +53,8 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity):
@esphome_state_property
def current_option(self) -> str | None:
"""Return the state of the entity."""
if self._state.missing_state:
return None
return self._state.state
state = self._state
return None if state.missing_state else state.state
async def async_select_option(self, option: str) -> None:
"""Change the selected option."""

View file

@ -2,9 +2,19 @@
from __future__ import annotations
from asyncio import Event
from collections.abc import Callable
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic
from aioesphomeapi import (
APIClient,
APIVersion,
DeviceInfo,
EntityInfo,
EntityState,
ReconnectLogic,
UserService,
)
import pytest
from zeroconf import Zeroconf
@ -82,7 +92,7 @@ async def init_integration(
@pytest.fixture
def mock_client(mock_device_info):
def mock_client(mock_device_info) -> APIClient:
"""Mock APIClient."""
mock_client = Mock(spec=APIClient)
@ -132,49 +142,72 @@ async def mock_dashboard(hass):
yield data
async def _mock_generic_device_entry(
hass: HomeAssistant,
mock_client: APIClient,
mock_device_info: dict[str, Any],
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
states: list[EntityState],
) -> MockConfigEntry:
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "test.local",
CONF_PORT: 6053,
CONF_PASSWORD: "",
},
)
entry.add_to_hass(hass)
device_info = DeviceInfo(
name="test",
friendly_name="Test",
mac_address="11:22:33:44:55:aa",
esphome_version="1.0.0",
**mock_device_info,
)
async def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
"""Subscribe to state."""
for state in states:
callback(state)
mock_client.device_info = AsyncMock(return_value=device_info)
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
mock_client.list_entities_services = AsyncMock(
return_value=mock_list_entities_services
)
mock_client.subscribe_states = _subscribe_states
try_connect_done = Event()
real_try_connect = ReconnectLogic._try_connect
async def mock_try_connect(self):
"""Set an event when ReconnectLogic._try_connect has been awaited."""
result = await real_try_connect(self)
try_connect_done.set()
return result
with patch.object(ReconnectLogic, "_try_connect", mock_try_connect):
await hass.config_entries.async_setup(entry.entry_id)
await try_connect_done.wait()
await hass.async_block_till_done()
return entry
@pytest.fixture
async def mock_voice_assistant_entry(
hass: HomeAssistant,
mock_client,
) -> MockConfigEntry:
mock_client: APIClient,
):
"""Set up an ESPHome entry with voice assistant."""
async def _mock_voice_assistant_entry(version: int):
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "test.local",
CONF_PORT: 6053,
CONF_PASSWORD: "",
},
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
return await _mock_generic_device_entry(
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
)
entry.add_to_hass(hass)
device_info = DeviceInfo(
name="test",
friendly_name="Test",
voice_assistant_version=version,
mac_address="11:22:33:44:55:aa",
esphome_version="1.0.0",
)
mock_client.device_info = AsyncMock(return_value=device_info)
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
try_connect_done = Event()
real_try_connect = ReconnectLogic._try_connect
async def mock_try_connect(self):
"""Set an event when ReconnectLogic._try_connect has been awaited."""
result = await real_try_connect(self)
try_connect_done.set()
return result
with patch.object(ReconnectLogic, "_try_connect", mock_try_connect):
await hass.config_entries.async_setup(entry.entry_id)
await try_connect_done.wait()
return entry
return _mock_voice_assistant_entry
@ -189,3 +222,22 @@ async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfi
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=2)
@pytest.fixture
async def mock_generic_device_entry(
hass: HomeAssistant,
) -> MockConfigEntry:
"""Set up an ESPHome entry."""
async def _mock_device_entry(
mock_client: APIClient,
entity_info: list[EntityInfo],
user_service: list[UserService],
states: list[EntityState],
) -> MockConfigEntry:
return await _mock_generic_device_entry(
hass, mock_client, {}, (entity_info, user_service), states
)
return _mock_device_entry

View file

@ -1,6 +1,16 @@
"""Test ESPHome selects."""
from unittest.mock import call
from aioesphomeapi import APIClient, SelectInfo, SelectState
from homeassistant.components.select import (
ATTR_OPTION,
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant
@ -13,3 +23,37 @@ async def test_pipeline_selector(
state = hass.states.get("select.test_assist_pipeline")
assert state is not None
assert state.state == "preferred"
async def test_select_generic_entity(
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
) -> None:
"""Test a generic select entity."""
entity_info = [
SelectInfo(
object_id="myselect",
key=1,
name="my select",
unique_id="my_select",
options=["a", "b"],
)
]
states = [SelectState(key=1, state="a")]
user_service = []
await mock_generic_device_entry(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
state = hass.states.get("select.test_my_select")
assert state is not None
assert state.state == "a"
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: "select.test_my_select", ATTR_OPTION: "b"},
blocking=True,
)
mock_client.select_command.assert_has_calls([call(1, "b")])