Incorporate assist satellite entity feedback (#124727)
* Incorporate feedback * Raise value error * Clean up entity description * More cleanup * Move some things around * Add a basic test * Whatever * Update CODEOWNERS * Add tests * Test tts response finished * Fix test * Wrong place --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
644427ecc7
commit
a51de1df3c
9 changed files with 254 additions and 45 deletions
|
@ -144,6 +144,7 @@ build.json @home-assistant/supervisor
|
|||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||
/tests/components/assist_satellite/ @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
|
|
@ -9,13 +9,14 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||
from .models import AssistSatelliteState
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteState",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Assist satellite entity."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable
|
||||
import time
|
||||
from typing import Final
|
||||
|
@ -15,7 +16,6 @@ from homeassistant.components.assist_pipeline import (
|
|||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
|
@ -26,18 +26,16 @@ from .models import AssistSatelliteState
|
|||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||
"""A class that describes binary sensor entities."""
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description = EntityDescription(
|
||||
key="assist_satellite",
|
||||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
entity_description: AssistSatelliteEntityDescription
|
||||
_attr_should_poll = False
|
||||
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
_attr_state: AssistSatelliteState | None = None
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
@ -58,24 +56,27 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
|
||||
if pipeline_entity_id:
|
||||
# Resolve pipeline by name
|
||||
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
|
||||
if (pipeline_entity_state is not None) and (
|
||||
pipeline_entity_state.state != OPTION_PREFERRED
|
||||
):
|
||||
if (
|
||||
pipeline_entity_state := self.hass.states.get(pipeline_entity_id)
|
||||
) is None:
|
||||
raise ValueError("Pipeline entity not found")
|
||||
|
||||
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||
# Resolve pipeline by name
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
pipeline_id = pipeline.id
|
||||
break
|
||||
|
||||
if vad_sensitivity_entity_id:
|
||||
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
|
||||
if vad_sensitivity_state is not None:
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
if (
|
||||
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||
) is None:
|
||||
raise ValueError("VAD sensitivity entity not found")
|
||||
|
||||
device_id: str | None = None
|
||||
if self.registry_entry is not None:
|
||||
device_id = self.registry_entry.device_id
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
|
@ -105,7 +106,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self.on_pipeline_event,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
|
@ -123,24 +124,32 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
audio_settings=AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
||||
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type == PipelineEventType.WAKE_WORD_START:
|
||||
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type == PipelineEventType.STT_START:
|
||||
elif event.type is PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type == PipelineEventType.INTENT_START:
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type == PipelineEventType.TTS_START:
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type == PipelineEventType.RUN_END:
|
||||
elif event.type is PipelineEventType.RUN_END:
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
self.on_pipeline_event(event)
|
||||
|
||||
def _set_state(self, state: AssistSatelliteState):
|
||||
"""Set the entity's state."""
|
||||
self._attr_state = state
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
{
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"name": "Assist satellite",
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,11 @@ from homeassistant.components.assist_pipeline import (
|
|||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteEntity
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
@ -78,6 +82,12 @@ async def async_setup_entry(
|
|||
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||
"""Assist satellite for VoIP devices."""
|
||||
|
||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
|
@ -108,8 +118,8 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
if self.voip_device.protocol == self:
|
||||
self.voip_device.protocol = None
|
||||
assert self.voip_device.protocol == self
|
||||
self.voip_device.protocol = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
|
@ -188,8 +198,6 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
super().on_pipeline_event(event)
|
||||
|
||||
if event.type == PipelineEventType.STT_END:
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
self._processing_tone_done.clear()
|
||||
|
|
|
@ -13,10 +13,10 @@
|
|||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
1
tests/components/assist_satellite/__init__.py
Normal file
1
tests/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Tests for Assist Satellite."""
|
104
tests/components/assist_satellite/conftest.py
Normal file
104
tests/components/assist_satellite/conftest.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
"""Test helpers for Assist Satellite."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteEntity,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test_satellite"
|
||||
|
||||
|
||||
class MockAssistSatellite(AssistSatelliteEntity):
|
||||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
self.events = []
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity() -> MockAssistSatellite:
|
||||
"""Mock Assist Satellite Entity."""
|
||||
return MockAssistSatellite()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Mock config entry."""
|
||||
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Initialize components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test tts platform via config entry."""
|
||||
async_add_entities([entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
86
tests/components/assist_satellite/test_entity.py
Normal file
86
tests/components/assist_satellite/test_entity.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
"""Test the Assist Satellite entity."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteState
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
||||
|
||||
|
||||
async def test_entity_state(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test entity state represent events."""
|
||||
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
context = Context()
|
||||
audio_stream = object()
|
||||
|
||||
entity.async_set_context(context)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||
) as mock_start_pipeline:
|
||||
await entity._async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
assert mock_start_pipeline.called
|
||||
kwargs = mock_start_pipeline.call_args[1]
|
||||
assert kwargs["context"] is context
|
||||
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] == "wav"
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
)
|
||||
assert kwargs["start_stage"] == PipelineStage.STT
|
||||
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||
|
||||
for event_type, expected_state in (
|
||||
(PipelineEventType.RUN_START, STATE_UNKNOWN),
|
||||
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||
):
|
||||
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == expected_state, event_type
|
||||
|
||||
entity.tts_response_finished()
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
Loading…
Add table
Reference in a new issue