Fix race and add test coverage for esphome select platform (#95019)
This commit is contained in:
parent
ef2669afe4
commit
65a5244d5a
5 changed files with 148 additions and 47 deletions
|
@ -319,7 +319,6 @@ omit =
|
||||||
homeassistant/components/esphome/lock.py
|
homeassistant/components/esphome/lock.py
|
||||||
homeassistant/components/esphome/media_player.py
|
homeassistant/components/esphome/media_player.py
|
||||||
homeassistant/components/esphome/number.py
|
homeassistant/components/esphome/number.py
|
||||||
homeassistant/components/esphome/select.py
|
|
||||||
homeassistant/components/esphome/sensor.py
|
homeassistant/components/esphome/sensor.py
|
||||||
homeassistant/components/esphome/switch.py
|
homeassistant/components/esphome/switch.py
|
||||||
homeassistant/components/etherscan/sensor.py
|
homeassistant/components/etherscan/sensor.py
|
||||||
|
|
|
@ -845,6 +845,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
||||||
self._on_static_info_update,
|
self._on_static_info_update,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self._update_state_from_entry_data()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _on_static_info_update(self, static_info: EntityInfo) -> None:
|
def _on_static_info_update(self, static_info: EntityInfo) -> None:
|
||||||
|
@ -868,11 +869,9 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
||||||
self._attr_icon = None
|
self._attr_icon = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _on_state_update(self) -> None:
|
def _update_state_from_entry_data(self) -> None:
|
||||||
"""Call when state changed.
|
"""Update state from entry data."""
|
||||||
|
|
||||||
Behavior can be changed in child classes
|
|
||||||
"""
|
|
||||||
state = self._entry_data.state
|
state = self._entry_data.state
|
||||||
key = self._key
|
key = self._key
|
||||||
state_type = self._state_type
|
state_type = self._state_type
|
||||||
|
@ -880,6 +879,14 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
|
||||||
if has_state:
|
if has_state:
|
||||||
self._state = cast(_StateT, state[state_type][key])
|
self._state = cast(_StateT, state[state_type][key])
|
||||||
self._has_state = has_state
|
self._has_state = has_state
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _on_state_update(self) -> None:
|
||||||
|
"""Call when state changed.
|
||||||
|
|
||||||
|
Behavior can be changed in child classes
|
||||||
|
"""
|
||||||
|
self._update_state_from_entry_data()
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -53,9 +53,8 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity):
|
||||||
@esphome_state_property
|
@esphome_state_property
|
||||||
def current_option(self) -> str | None:
|
def current_option(self) -> str | None:
|
||||||
"""Return the state of the entity."""
|
"""Return the state of the entity."""
|
||||||
if self._state.missing_state:
|
state = self._state
|
||||||
return None
|
return None if state.missing_state else state.state
|
||||||
return self._state.state
|
|
||||||
|
|
||||||
async def async_select_option(self, option: str) -> None:
|
async def async_select_option(self, option: str) -> None:
|
||||||
"""Change the selected option."""
|
"""Change the selected option."""
|
||||||
|
|
|
@ -2,9 +2,19 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from asyncio import Event
|
from asyncio import Event
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic
|
from aioesphomeapi import (
|
||||||
|
APIClient,
|
||||||
|
APIVersion,
|
||||||
|
DeviceInfo,
|
||||||
|
EntityInfo,
|
||||||
|
EntityState,
|
||||||
|
ReconnectLogic,
|
||||||
|
UserService,
|
||||||
|
)
|
||||||
import pytest
|
import pytest
|
||||||
from zeroconf import Zeroconf
|
from zeroconf import Zeroconf
|
||||||
|
|
||||||
|
@ -82,7 +92,7 @@ async def init_integration(
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_client(mock_device_info):
|
def mock_client(mock_device_info) -> APIClient:
|
||||||
"""Mock APIClient."""
|
"""Mock APIClient."""
|
||||||
mock_client = Mock(spec=APIClient)
|
mock_client = Mock(spec=APIClient)
|
||||||
|
|
||||||
|
@ -132,14 +142,13 @@ async def mock_dashboard(hass):
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
async def _mock_generic_device_entry(
|
||||||
async def mock_voice_assistant_entry(
|
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_client,
|
mock_client: APIClient,
|
||||||
|
mock_device_info: dict[str, Any],
|
||||||
|
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
|
||||||
|
states: list[EntityState],
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Set up an ESPHome entry with voice assistant."""
|
|
||||||
|
|
||||||
async def _mock_voice_assistant_entry(version: int):
|
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={
|
data={
|
||||||
|
@ -153,13 +162,22 @@ async def mock_voice_assistant_entry(
|
||||||
device_info = DeviceInfo(
|
device_info = DeviceInfo(
|
||||||
name="test",
|
name="test",
|
||||||
friendly_name="Test",
|
friendly_name="Test",
|
||||||
voice_assistant_version=version,
|
|
||||||
mac_address="11:22:33:44:55:aa",
|
mac_address="11:22:33:44:55:aa",
|
||||||
esphome_version="1.0.0",
|
esphome_version="1.0.0",
|
||||||
|
**mock_device_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _subscribe_states(callback: Callable[[EntityState], None]) -> None:
|
||||||
|
"""Subscribe to state."""
|
||||||
|
for state in states:
|
||||||
|
callback(state)
|
||||||
|
|
||||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||||
|
mock_client.list_entities_services = AsyncMock(
|
||||||
|
return_value=mock_list_entities_services
|
||||||
|
)
|
||||||
|
mock_client.subscribe_states = _subscribe_states
|
||||||
|
|
||||||
try_connect_done = Event()
|
try_connect_done = Event()
|
||||||
real_try_connect = ReconnectLogic._try_connect
|
real_try_connect = ReconnectLogic._try_connect
|
||||||
|
@ -174,8 +192,23 @@ async def mock_voice_assistant_entry(
|
||||||
await hass.config_entries.async_setup(entry.entry_id)
|
await hass.config_entries.async_setup(entry.entry_id)
|
||||||
await try_connect_done.wait()
|
await try_connect_done.wait()
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_voice_assistant_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_client: APIClient,
|
||||||
|
):
|
||||||
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
|
|
||||||
|
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
|
||||||
|
return await _mock_generic_device_entry(
|
||||||
|
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
|
||||||
|
)
|
||||||
|
|
||||||
return _mock_voice_assistant_entry
|
return _mock_voice_assistant_entry
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,3 +222,22 @@ async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfi
|
||||||
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||||
"""Set up an ESPHome entry with voice assistant."""
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
return await mock_voice_assistant_entry(version=2)
|
return await mock_voice_assistant_entry(version=2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_generic_device_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> MockConfigEntry:
|
||||||
|
"""Set up an ESPHome entry."""
|
||||||
|
|
||||||
|
async def _mock_device_entry(
|
||||||
|
mock_client: APIClient,
|
||||||
|
entity_info: list[EntityInfo],
|
||||||
|
user_service: list[UserService],
|
||||||
|
states: list[EntityState],
|
||||||
|
) -> MockConfigEntry:
|
||||||
|
return await _mock_generic_device_entry(
|
||||||
|
hass, mock_client, {}, (entity_info, user_service), states
|
||||||
|
)
|
||||||
|
|
||||||
|
return _mock_device_entry
|
||||||
|
|
|
@ -1,6 +1,16 @@
|
||||||
"""Test ESPHome selects."""
|
"""Test ESPHome selects."""
|
||||||
|
|
||||||
|
|
||||||
|
from unittest.mock import call
|
||||||
|
|
||||||
|
from aioesphomeapi import APIClient, SelectInfo, SelectState
|
||||||
|
|
||||||
|
from homeassistant.components.select import (
|
||||||
|
ATTR_OPTION,
|
||||||
|
DOMAIN as SELECT_DOMAIN,
|
||||||
|
SERVICE_SELECT_OPTION,
|
||||||
|
)
|
||||||
|
from homeassistant.const import ATTR_ENTITY_ID
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,3 +23,37 @@ async def test_pipeline_selector(
|
||||||
state = hass.states.get("select.test_assist_pipeline")
|
state = hass.states.get("select.test_assist_pipeline")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state == "preferred"
|
assert state.state == "preferred"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_select_generic_entity(
|
||||||
|
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
|
||||||
|
) -> None:
|
||||||
|
"""Test a generic select entity."""
|
||||||
|
entity_info = [
|
||||||
|
SelectInfo(
|
||||||
|
object_id="myselect",
|
||||||
|
key=1,
|
||||||
|
name="my select",
|
||||||
|
unique_id="my_select",
|
||||||
|
options=["a", "b"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
states = [SelectState(key=1, state="a")]
|
||||||
|
user_service = []
|
||||||
|
await mock_generic_device_entry(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=entity_info,
|
||||||
|
user_service=user_service,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
state = hass.states.get("select.test_my_select")
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == "a"
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
SELECT_DOMAIN,
|
||||||
|
SERVICE_SELECT_OPTION,
|
||||||
|
{ATTR_ENTITY_ID: "select.test_my_select", ATTR_OPTION: "b"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
mock_client.select_command.assert_has_calls([call(1, "b")])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue