Incorporate assist satellite entity feedback (#124727)

* Incorporate feedback

* Raise value error

* Clean up entity description

* More cleanup

* Move some things around

* Add a basic test

* Whatever

* Update CODEOWNERS

* Add tests

* Test tts response finished

* Fix test

* Wrong place

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-08-28 18:03:48 -05:00 committed by GitHub
parent 644427ecc7
commit a51de1df3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 254 additions and 45 deletions

View file

@ -144,6 +144,7 @@ build.json @home-assistant/supervisor
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/tests/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View file

@ -9,13 +9,14 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteState
__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
]
_LOGGER = logging.getLogger(__name__)

View file

@ -1,5 +1,6 @@
"""Assist satellite entity."""
from abc import abstractmethod
from collections.abc import AsyncIterable
import time
from typing import Final
@ -15,7 +16,6 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.const import EntityCategory
from homeassistant.core import Context
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
@ -26,18 +26,16 @@ from .models import AssistSatelliteState
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
"""A class that describes binary sensor entities."""
class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""
entity_description = EntityDescription(
key="assist_satellite",
translation_key="assist_satellite",
entity_category=EntityCategory.CONFIG,
)
_attr_has_entity_name = True
_attr_name = None
entity_description: AssistSatelliteEntityDescription
_attr_should_poll = False
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
_attr_state: AssistSatelliteState | None = None
_conversation_id: str | None = None
_conversation_id_time: float | None = None
@ -58,24 +56,27 @@ class AssistSatelliteEntity(entity.Entity):
vad_sensitivity = vad.VadSensitivity.DEFAULT
if pipeline_entity_id:
# Resolve pipeline by name
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
if (pipeline_entity_state is not None) and (
pipeline_entity_state.state != OPTION_PREFERRED
):
if (
pipeline_entity_state := self.hass.states.get(pipeline_entity_id)
) is None:
raise ValueError("Pipeline entity not found")
if pipeline_entity_state.state != OPTION_PREFERRED:
# Resolve pipeline by name
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
pipeline_id = pipeline.id
break
if vad_sensitivity_entity_id:
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
if vad_sensitivity_state is not None:
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
if (
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
) is None:
raise ValueError("VAD sensitivity entity not found")
device_id: str | None = None
if self.registry_entry is not None:
device_id = self.registry_entry.device_id
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
device_id = self.registry_entry.device_id if self.registry_entry else None
# Refresh context if necessary
if (
@ -105,7 +106,7 @@ class AssistSatelliteEntity(entity.Entity):
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self.on_pipeline_event,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
@ -123,24 +124,32 @@ class AssistSatelliteEntity(entity.Entity):
audio_settings=AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
),
start_stage=start_stage,
end_stage=end_stage,
)
@abstractmethod
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
if event.type == PipelineEventType.WAKE_WORD_START:
if event.type is PipelineEventType.WAKE_WORD_START:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
elif event.type == PipelineEventType.STT_START:
elif event.type is PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
elif event.type == PipelineEventType.INTENT_START:
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type == PipelineEventType.TTS_START:
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True
self._set_state(AssistSatelliteState.RESPONDING)
elif event.type == PipelineEventType.RUN_END:
elif event.type is PipelineEventType.RUN_END:
if not self._run_has_tts:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
self.on_pipeline_event(event)
def _set_state(self, state: AssistSatelliteState):
"""Set the entity's state."""
self._attr_state = state

View file

@ -1,13 +1,12 @@
{
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
"entity_component": {
"_": {
"name": "Assist satellite",
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
}

View file

@ -19,7 +19,11 @@ from homeassistant.components.assist_pipeline import (
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntity
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -78,6 +82,12 @@ async def async_setup_entry(
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
"""Assist satellite for VoIP devices."""
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_has_entity_name = True
_attr_name = None
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
def __init__(
self,
hass: HomeAssistant,
@ -108,8 +118,8 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
if self.voip_device.protocol == self:
self.voip_device.protocol = None
assert self.voip_device.protocol == self
self.voip_device.protocol = None
# -------------------------------------------------------------------------
# VoIP
@ -188,8 +198,6 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
super().on_pipeline_event(event)
if event.type == PipelineEventType.STT_END:
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
self._processing_tone_done.clear()

View file

@ -13,10 +13,10 @@
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
}
}
},

View file

@ -0,0 +1 @@
"""Tests for Assist Satellite."""

View file

@ -0,0 +1,104 @@
"""Test helpers for Assist Satellite."""
from unittest.mock import Mock
import pytest
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
TEST_DOMAIN = "test_satellite"
class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""
_attr_name = "Test Entity"
def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)
@pytest.fixture
def entity() -> MockAssistSatellite:
"""Mock Assist Satellite Entity."""
return MockAssistSatellite()
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Mock config entry."""
entry = MockConfigEntry(domain=TEST_DOMAIN)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_components(
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Initialize components."""
assert await async_setup_component(hass, "homeassistant", {})
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
async_add_entities([entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry

View file

@ -0,0 +1,86 @@
"""Test the Assist Satellite entity."""
from unittest.mock import patch
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
vad,
)
from homeassistant.components.assist_satellite import AssistSatelliteState
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import Context, HomeAssistant
from .conftest import MockAssistSatellite
ENTITY_ID = "assist_satellite.test_entity"
async def test_entity_state(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test entity state represent events."""
state = hass.states.get(ENTITY_ID)
assert state is not None
assert state.state == STATE_UNKNOWN
context = Context()
audio_stream = object()
entity.async_set_context(context)
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
) as mock_start_pipeline:
await entity._async_accept_pipeline_from_satellite(audio_stream)
assert mock_start_pipeline.called
kwargs = mock_start_pipeline.call_args[1]
assert kwargs["context"] is context
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
assert kwargs["stt_stream"] is audio_stream
assert kwargs["pipeline_id"] is None
assert kwargs["device_id"] is None
assert kwargs["tts_audio_output"] == "wav"
assert kwargs["wake_word_phrase"] is None
assert kwargs["audio_settings"] == AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
)
assert kwargs["start_stage"] == PipelineStage.STT
assert kwargs["end_stage"] == PipelineStage.TTS
for event_type, expected_state in (
(PipelineEventType.RUN_START, STATE_UNKNOWN),
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
):
kwargs["event_callback"](PipelineEvent(event_type, {}))
state = hass.states.get(ENTITY_ID)
assert state.state == expected_state, event_type
entity.tts_response_finished()
state = hass.states.get(ENTITY_ID)
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD