Fix race and add test coverage for esphome select platform (#95019)
This commit is contained in:
parent
ef2669afe4
commit
65a5244d5a
5 changed files with 148 additions and 47 deletions
|
@ -319,7 +319,6 @@ omit =
|
|||
homeassistant/components/esphome/lock.py
|
||||
homeassistant/components/esphome/media_player.py
|
||||
homeassistant/components/esphome/number.py
|
||||
homeassistant/components/esphome/select.py
|
||||
homeassistant/components/esphome/sensor.py
|
||||
homeassistant/components/esphome/switch.py
|
||||
homeassistant/components/etherscan/sensor.py
|
||||
|
|
|
@ -845,6 +845,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
self._on_static_info_update,
|
||||
)
|
||||
)
|
||||
self._update_state_from_entry_data()
|
||||
|
||||
@callback
|
||||
def _on_static_info_update(self, static_info: EntityInfo) -> None:
|
||||
|
@ -868,11 +869,9 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
self._attr_icon = None
|
||||
|
||||
@callback
|
||||
def _on_state_update(self) -> None:
|
||||
"""Call when state changed.
|
||||
def _update_state_from_entry_data(self) -> None:
|
||||
"""Update state from entry data."""
|
||||
|
||||
Behavior can be changed in child classes
|
||||
"""
|
||||
state = self._entry_data.state
|
||||
key = self._key
|
||||
state_type = self._state_type
|
||||
|
@ -880,6 +879,14 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
|||
if has_state:
|
||||
self._state = cast(_StateT, state[state_type][key])
|
||||
self._has_state = has_state
|
||||
|
||||
@callback
|
||||
def _on_state_update(self) -> None:
|
||||
"""Call when state changed.
|
||||
|
||||
Behavior can be changed in child classes
|
||||
"""
|
||||
self._update_state_from_entry_data()
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
|
|
|
@ -53,9 +53,8 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity):
|
|||
@esphome_state_property
|
||||
def current_option(self) -> str | None:
|
||||
"""Return the state of the entity."""
|
||||
if self._state.missing_state:
|
||||
return None
|
||||
return self._state.state
|
||||
state = self._state
|
||||
return None if state.missing_state else state.state
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Change the selected option."""
|
||||
|
|
|
@ -2,9 +2,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from asyncio import Event
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
APIVersion,
|
||||
DeviceInfo,
|
||||
EntityInfo,
|
||||
EntityState,
|
||||
ReconnectLogic,
|
||||
UserService,
|
||||
)
|
||||
import pytest
|
||||
from zeroconf import Zeroconf
|
||||
|
||||
|
@ -82,7 +92,7 @@ async def init_integration(
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client(mock_device_info):
|
||||
def mock_client(mock_device_info) -> APIClient:
|
||||
"""Mock APIClient."""
|
||||
mock_client = Mock(spec=APIClient)
|
||||
|
||||
|
@ -132,49 +142,72 @@ async def mock_dashboard(hass):
|
|||
yield data
|
||||
|
||||
|
||||
async def _mock_generic_device_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_device_info: dict[str, Any],
|
||||
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
|
||||
states: list[EntityState],
|
||||
) -> MockConfigEntry:
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
device_info = DeviceInfo(
|
||||
name="test",
|
||||
friendly_name="Test",
|
||||
mac_address="11:22:33:44:55:aa",
|
||||
esphome_version="1.0.0",
|
||||
**mock_device_info,
|
||||
)
|
||||
|
||||
async def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
|
||||
"""Subscribe to state."""
|
||||
for state in states:
|
||||
callback(state)
|
||||
|
||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||
mock_client.list_entities_services = AsyncMock(
|
||||
return_value=mock_list_entities_services
|
||||
)
|
||||
mock_client.subscribe_states = _subscribe_states
|
||||
|
||||
try_connect_done = Event()
|
||||
real_try_connect = ReconnectLogic._try_connect
|
||||
|
||||
async def mock_try_connect(self):
|
||||
"""Set an event when ReconnectLogic._try_connect has been awaited."""
|
||||
result = await real_try_connect(self)
|
||||
try_connect_done.set()
|
||||
return result
|
||||
|
||||
with patch.object(ReconnectLogic, "_try_connect", mock_try_connect):
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await try_connect_done.wait()
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_voice_assistant_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_client,
|
||||
) -> MockConfigEntry:
|
||||
mock_client: APIClient,
|
||||
):
|
||||
"""Set up an ESPHome entry with voice assistant."""
|
||||
|
||||
async def _mock_voice_assistant_entry(version: int):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
|
||||
return await _mock_generic_device_entry(
|
||||
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
device_info = DeviceInfo(
|
||||
name="test",
|
||||
friendly_name="Test",
|
||||
voice_assistant_version=version,
|
||||
mac_address="11:22:33:44:55:aa",
|
||||
esphome_version="1.0.0",
|
||||
)
|
||||
|
||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||
|
||||
try_connect_done = Event()
|
||||
real_try_connect = ReconnectLogic._try_connect
|
||||
|
||||
async def mock_try_connect(self):
|
||||
"""Set an event when ReconnectLogic._try_connect has been awaited."""
|
||||
result = await real_try_connect(self)
|
||||
try_connect_done.set()
|
||||
return result
|
||||
|
||||
with patch.object(ReconnectLogic, "_try_connect", mock_try_connect):
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await try_connect_done.wait()
|
||||
|
||||
return entry
|
||||
|
||||
return _mock_voice_assistant_entry
|
||||
|
||||
|
@ -189,3 +222,22 @@ async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfi
|
|||
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||
"""Set up an ESPHome entry with voice assistant."""
|
||||
return await mock_voice_assistant_entry(version=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_generic_device_entry(
|
||||
hass: HomeAssistant,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up an ESPHome entry."""
|
||||
|
||||
async def _mock_device_entry(
|
||||
mock_client: APIClient,
|
||||
entity_info: list[EntityInfo],
|
||||
user_service: list[UserService],
|
||||
states: list[EntityState],
|
||||
) -> MockConfigEntry:
|
||||
return await _mock_generic_device_entry(
|
||||
hass, mock_client, {}, (entity_info, user_service), states
|
||||
)
|
||||
|
||||
return _mock_device_entry
|
||||
|
|
|
@ -1,6 +1,16 @@
|
|||
"""Test ESPHome selects."""
|
||||
|
||||
|
||||
from unittest.mock import call
|
||||
|
||||
from aioesphomeapi import APIClient, SelectInfo, SelectState
|
||||
|
||||
from homeassistant.components.select import (
|
||||
ATTR_OPTION,
|
||||
DOMAIN as SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
|
@ -13,3 +23,37 @@ async def test_pipeline_selector(
|
|||
state = hass.states.get("select.test_assist_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
|
||||
|
||||
async def test_select_generic_entity(
|
||||
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
|
||||
) -> None:
|
||||
"""Test a generic select entity."""
|
||||
entity_info = [
|
||||
SelectInfo(
|
||||
object_id="myselect",
|
||||
key=1,
|
||||
name="my select",
|
||||
unique_id="my_select",
|
||||
options=["a", "b"],
|
||||
)
|
||||
]
|
||||
states = [SelectState(key=1, state="a")]
|
||||
user_service = []
|
||||
await mock_generic_device_entry(
|
||||
mock_client=mock_client,
|
||||
entity_info=entity_info,
|
||||
user_service=user_service,
|
||||
states=states,
|
||||
)
|
||||
state = hass.states.get("select.test_my_select")
|
||||
assert state is not None
|
||||
assert state.state == "a"
|
||||
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: "select.test_my_select", ATTR_OPTION: "b"},
|
||||
blocking=True,
|
||||
)
|
||||
mock_client.select_command.assert_has_calls([call(1, "b")])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue