Compare commits

...
Sign in to create a new pull request.

8 commits

Author SHA1 Message Date
Paulus Schoutsen
898bb56519 Address comments 2024-08-30 13:16:13 +00:00
Michael Hansen
1a6affc426 Resolve media id if present 2024-08-29 16:21:47 -05:00
Michael Hansen
93cc266b06 Dynamically set supported features 2024-08-29 16:06:52 -05:00
Michael Hansen
f0c49b3995 Add async_announce 2024-08-29 15:41:23 -05:00
Michael Hansen
d375bfaefe Update tests 2024-08-29 13:18:34 -05:00
Michael Hansen
7fe4a52d59 Before refactor 2024-08-29 13:13:10 -05:00
Michael Hansen
a51de1df3c
Incorporate assist satellite entity feedback (#124727)
* Incorporate feedback

* Raise value error

* Clean up entity description

* More cleanup

* Move some things around

* Add a basic test

* Whatever

* Update CODEOWNERS

* Add tests

* Test tts response finished

* Fix test

* Wrong place

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
2024-08-29 01:03:48 +02:00
Michael Hansen
644427ecc7
Add Assist satellite entity + VoIP (#123830)
* Add assist_satellite and implement VoIP

* Fix tests

* More tests

* Improve test

* Update entity state

* Set state correctly

* Move more functionality into base class

* Move RTP protocol into entity

* Fix tests

* Remove string

* Move to util method

* Align states better with pipeline events

* Remove public async_get_satellite_entity

* WAITING_FOR_WAKE_WORD

* Pass entity ids for pipeline/vad sensitivity

* Remove connect/disconnect

* Clean up

* Final cleanup
2024-08-25 16:19:36 +02:00
36 changed files with 2224 additions and 672 deletions

View file

@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
/tests/components/aseko_pool_live/ @milanmeu
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/tests/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View file

@ -16,6 +16,7 @@ from .const import (
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
OPTION_PREFERRED,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
@ -57,6 +58,7 @@ __all__ = (
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
"OPTION_PREFERRED",
"SAMPLES_PER_CHUNK",
"SAMPLE_RATE",
"SAMPLE_WIDTH",

View file

@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
MS_PER_CHUNK = 10
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"

View file

@ -504,7 +504,7 @@ class AudioSettings:
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
silence_seconds: float = 0.5
silence_seconds: float = 0.7
"""Seconds of silence after voice command has ended."""
def __post_init__(self) -> None:
@ -906,6 +906,8 @@ class PipelineRun:
metadata,
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
)
except (asyncio.CancelledError, TimeoutError):
raise # expected
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(

View file

@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(

View file

@ -0,0 +1,65 @@
"""Base class for assist satellite entities."""
import logging
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
from .websocket_api import async_register_websocket_api
# Names exported by ``from homeassistant.components.assist_satellite import *``.
__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature",
]
# Module logger; also handed to the EntityComponent created in async_setup.
_LOGGER = logging.getLogger(__name__)
# Alias of the base platform schema from config_validation.
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the assist satellite entity component.

    Creates the entity component and stores it in ``hass.data[DOMAIN]``,
    registers the websocket API, and registers the ``announce`` entity
    service for entities that support the ANNOUNCE feature.

    Always returns True.
    """
    component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
        _LOGGER, DOMAIN, hass
    )
    await component.async_setup(config)

    async_register_websocket_api(hass)

    component.async_register_entity_service(
        "announce",
        vol.All(
            # Must be an entity service schema so that the targeting fields
            # (entity_id, device_id, ...) pass validation.
            cv.make_entity_service_schema(
                {
                    vol.Optional("text"): str,
                    # Key must match the keyword argument of
                    # AssistSatelliteEntity.async_announce, because service
                    # data is passed to the method as keyword arguments.
                    vol.Optional("media_id"): str,
                }
            ),
            cv.has_at_least_one_key("text", "media_id"),
        ),
        # Name of the entity method to call (was misspelled "async_annonuce").
        "async_announce",
        [AssistSatelliteEntityFeature.ANNOUNCE],
    )

    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward setup of a config entry to the assist satellite component."""
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    is_set_up = await satellites.async_setup_entry(entry)
    return is_set_up
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward unloading of a config entry to the assist satellite component."""
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    is_unloaded = await satellites.async_unload_entry(entry)
    return is_unloaded

View file

@ -0,0 +1,3 @@
"""Constants for assist satellite."""
# Domain name of the assist satellite entity integration.
DOMAIN = "assist_satellite"

View file

@ -0,0 +1,283 @@
"""Assist satellite entity."""
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
import logging
import time
from typing import Any, Final
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .errors import SatelliteBusyError
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
_LOGGER = logging.getLogger(__name__)
# A stored conversation id is discarded when the previous pipeline run was
# started more than this many seconds ago.
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
    """Entity description for assist satellite entities.

    Adds no fields of its own on top of EntityDescription.
    """
class AssistSatelliteEntity(entity.Entity):
    """Entity encapsulating the state and functionality of an Assist satellite.

    Subclasses must implement on_pipeline_event and, when they support
    announcements, _internal_async_announce.
    """

    entity_description: AssistSatelliteEntityDescription
    _attr_should_poll = False
    _attr_state: AssistSatelliteState | None = None
    _attr_supported_features = AssistSatelliteEntityFeature(0)

    # Conversation id reused between pipeline runs, and the time.monotonic()
    # value of the most recent run (used to expire the id).
    _conversation_id: str | None = None
    _conversation_id_time: float | None = None

    # True while _internal_async_announce is running.
    _is_announcing: bool = False
    # Created on TTS_START and set by tts_response_finished.
    _tts_finished_event: asyncio.Event | None = None
    # Pending future while the next wake word is being intercepted.
    _wake_word_future: asyncio.Future[str | None] | None = None

    @property
    def is_announcing(self) -> bool:
        """Return true if currently announcing."""
        return self._is_announcing

    async def async_announce(
        self,
        text: str | None = None,
        media_id: str | None = None,
        pipeline_entity_id: str | None = None,
    ) -> None:
        """Play an announcement on the satellite.

        If media_id is not provided, text is synthesized to audio using the
        TTS settings of the pipeline resolved from pipeline_entity_id. When
        pipeline_entity_id is omitted, a pipeline id of None is passed to
        async_get_pipeline (presumably selecting the preferred pipeline).

        Calls _internal_async_announce with the resolved media URL and expects
        it to block until the announcement is completed.

        Raises SatelliteBusyError if an announcement is already in progress.
        """
        if text is None:
            text = ""

        if not media_id:
            # Synthesize audio and get URL
            pipeline_id = self._resolve_pipeline(pipeline_entity_id)
            pipeline = async_get_pipeline(self.hass, pipeline_id)

            tts_options: dict[str, Any] = {}
            if pipeline.tts_voice is not None:
                tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

            media_id = tts_generate_media_source_id(
                self.hass,
                text,
                engine=pipeline.tts_engine,
                language=pipeline.tts_language,
                options=tts_options,
            )

        if media_source.is_media_source_id(media_id):
            media = await media_source.async_resolve_media(
                self.hass,
                media_id,
                None,
            )
            media_id = media.url

        # Resolve to full URL
        media_id = async_process_play_media_url(self.hass, media_id)

        # No await between the check and the assignment, so two concurrent
        # calls cannot both pass the busy check.
        if self._is_announcing:
            raise SatelliteBusyError

        self._is_announcing = True
        try:
            # Block until announcement is finished
            await self._internal_async_announce(media_id)
        finally:
            self._is_announcing = False

    async def _internal_async_announce(self, media_id: str) -> None:
        """Announce the media URL on the satellite and return when finished."""
        raise NotImplementedError

    @property
    def is_intercepting_wake_word(self) -> bool:
        """Return true if next wake word will be intercepted."""
        # A future that already has a result (or was cancelled) can no longer
        # accept a wake word; done() covers both cases and prevents a second
        # set_result() from raising InvalidStateError.
        return (self._wake_word_future is not None) and (
            not self._wake_word_future.done()
        )

    async def async_intercept_wake_word(self) -> str | None:
        """Intercept the next wake word from the satellite.

        Returns the detected wake word phrase or None.

        Raises SatelliteBusyError if an interception is already pending.
        """
        if self._wake_word_future is not None:
            raise SatelliteBusyError

        # Will cause next wake word to be intercepted in
        # _async_accept_pipeline_from_satellite
        self._wake_word_future = asyncio.Future()

        _LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)

        try:
            return await self._wake_word_future
        finally:
            self._wake_word_future = None

    async def _async_accept_pipeline_from_satellite(
        self,
        audio_stream: AsyncIterable[bytes],
        start_stage: PipelineStage = PipelineStage.STT,
        end_stage: PipelineStage = PipelineStage.TTS,
        pipeline_entity_id: str | None = None,
        vad_sensitivity_entity_id: str | None = None,
        wake_word_phrase: str | None = None,
    ) -> None:
        """Trigger an Assist pipeline in Home Assistant from a satellite.

        If a wake word interception is pending, the pipeline is not run:
        the wake word phrase is delivered to the waiting caller and a RUN_END
        event is emitted instead.

        Raises ValueError if the pipeline or VAD sensitivity entity id is
        given but has no state.
        """
        if self.is_intercepting_wake_word:
            # Intercepting wake word and immediately end pipeline
            _LOGGER.debug(
                "Intercepted wake word: %s (entity_id=%s)",
                wake_word_phrase,
                self.entity_id,
            )
            assert self._wake_word_future is not None
            self._wake_word_future.set_result(wake_word_phrase)
            self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
            return

        pipeline_id = self._resolve_pipeline(pipeline_entity_id)

        vad_sensitivity = vad.VadSensitivity.DEFAULT
        if vad_sensitivity_entity_id:
            if (
                vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
            ) is None:
                raise ValueError("VAD sensitivity entity not found")

            vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

        device_id = self.registry_entry.device_id if self.registry_entry else None

        # Refresh context if necessary
        if (
            (self._context is None)
            or (self._context_set is None)
            or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
        ):
            self.async_set_context(Context())

        assert self._context is not None

        # Reset conversation id if necessary
        if (self._conversation_id_time is None) or (
            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
        ):
            self._conversation_id = None

        if self._conversation_id is None:
            self._conversation_id = ulid.ulid()

        # Update timeout
        self._conversation_id_time = time.monotonic()

        # Set entity state based on pipeline events
        self._tts_finished_event = None

        await async_pipeline_from_audio_stream(
            self.hass,
            context=self._context,
            event_callback=self._internal_on_pipeline_event,
            stt_metadata=stt.SpeechMetadata(
                language="",  # set in async_pipeline_from_audio_stream
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            ),
            stt_stream=audio_stream,
            pipeline_id=pipeline_id,
            conversation_id=self._conversation_id,
            device_id=device_id,
            tts_audio_output="wav",
            wake_word_phrase=wake_word_phrase,
            audio_settings=AudioSettings(
                silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
            ),
            start_stage=start_stage,
            end_stage=end_stage,
        )

    @abstractmethod
    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Handle pipeline events."""

    def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage, then forward the event."""
        if event.type is PipelineEventType.WAKE_WORD_START:
            self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
        elif event.type is PipelineEventType.STT_START:
            self._set_state(AssistSatelliteState.LISTENING_COMMAND)
        elif event.type is PipelineEventType.INTENT_START:
            self._set_state(AssistSatelliteState.PROCESSING)
        elif event.type is PipelineEventType.TTS_START:
            # Wait until tts_response_finished is called to return to waiting state
            self._tts_finished_event = asyncio.Event()
            self._set_state(AssistSatelliteState.RESPONDING)
        elif event.type is PipelineEventType.RUN_END:
            # Only return to the wake word state here when no TTS response
            # was started during this run.
            if self._tts_finished_event is None:
                self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

        self.on_pipeline_event(event)

    def _set_state(self, state: AssistSatelliteState) -> None:
        """Set the entity's state and write it to the state machine."""
        self._attr_state = state
        self.async_write_ha_state()

    def tts_response_finished(self) -> None:
        """Tell entity that the text-to-speech response has finished playing."""
        self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

        if self._tts_finished_event is not None:
            self._tts_finished_event.set()

    def _resolve_pipeline(self, pipeline_entity_id: str | None) -> str | None:
        """Resolve pipeline from select entity to id.

        Returns None when no entity id is given, when the entity's state is
        OPTION_PREFERRED, or when no pipeline has a name matching the state.

        Raises ValueError if the entity id is given but has no state.
        """
        if not pipeline_entity_id:
            return None

        if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
            raise ValueError("Pipeline entity not found")

        if pipeline_entity_state.state != OPTION_PREFERRED:
            # Resolve pipeline by name
            for pipeline in async_get_pipelines(self.hass):
                if pipeline.name == pipeline_entity_state.state:
                    return pipeline.id

        return None

View file

@ -0,0 +1,11 @@
"""Errors for assist satellite."""
from homeassistant.exceptions import HomeAssistantError
class AssistSatelliteError(HomeAssistantError):
    """Common base for all errors raised by the assist satellite integration."""
class SatelliteBusyError(AssistSatelliteError):
    """The satellite is already busy and cannot accept the request."""

View file

@ -0,0 +1,7 @@
{
"entity_component": {
"_": {
"default": "mdi:microphone-message"
}
}
}

View file

@ -0,0 +1,9 @@
{
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}

View file

@ -0,0 +1,26 @@
"""Models for assist satellite."""
from enum import IntFlag, StrEnum
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
LISTENING_WAKE_WORD = "listening_wake_word"
"""Device is streaming audio for wake word detection to Home Assistant."""
LISTENING_COMMAND = "listening_command"
"""Device is streaming audio with the voice command to Home Assistant."""
PROCESSING = "processing"
"""Home Assistant is processing the voice command."""
RESPONDING = "responding"
"""Device is speaking the response."""
class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""
ANNOUNCE = 1
"""Device supports remotely triggered announcements."""

View file

@ -0,0 +1,13 @@
{
"entity_component": {
"_": {
"name": "Assist satellite",
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
}
}

View file

@ -0,0 +1,42 @@
"""Assist satellite Websocket API."""
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_component import EntityComponent
from .const import DOMAIN
from .entity import AssistSatelliteEntity
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register the assist satellite websocket commands with Home Assistant."""
    websocket_api.async_register_command(hass, websocket_intercept_wake_word)
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/intercept_wake_word",
        vol.Required("entity_id"): str,
    }
)
@websocket_api.async_response
async def websocket_intercept_wake_word(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Wait for the next wake word of a satellite and send it as the result.

    Replies with an "entity_not_found" error when the requested entity is
    not managed by the assist satellite component.
    """
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    target = satellites.get_entity(msg["entity_id"])

    if target is not None:
        phrase = await target.async_intercept_wake_word()
        connection.send_result(msg["id"], {"wake_word_phrase": phrase})
        return

    connection.send_error(msg["id"], "entity_not_found", "Entity not found")

View file

@ -0,0 +1,509 @@
"""Support for assist satellites in ESPHome."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable
from functools import partial
import io
import logging
import socket
from typing import Any, cast
import wave
from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import EsphomeAssistEntity
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
# Maps ESPHome voice assistant event types to Home Assistant pipeline event
# types. Used via from_hass() in on_pipeline_event to translate pipeline
# events before sending them to the device.
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
}
)
# Maps ESPHome timer event types to Home Assistant intent timer event types.
# Used via from_hass() in handle_timer_event.
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
EsphomeEnumMapper(
{
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
}
)
)
async def async_setup_entry(
    hass: HomeAssistant,
    entry: ESPHomeConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Add an Assist satellite entity when the device has voice assistant flags."""
    runtime = entry.runtime_data
    assert runtime.device_info is not None

    voice_flags = runtime.device_info.voice_assistant_feature_flags_compat(
        runtime.api_version
    )
    if not voice_flags:
        # Device reports no voice assistant features; nothing to add.
        return

    async_add_entities([EsphomeAssistSatellite(hass, entry, runtime)])
class EsphomeAssistSatellite(
    EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
):
    """Satellite running ESPHome."""

    entity_description = assist_satellite.AssistSatelliteEntityDescription(
        key="assist_satellite",
        translation_key="assist_satellite",
        entity_category=EntityCategory.CONFIG,
    )

    def __init__(
        self,
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        entry_data: RuntimeEntryData,
    ) -> None:
        """Initialize satellite."""
        super().__init__(entry_data)

        self.hass = hass
        self.config_entry = config_entry
        self.entry_data = entry_data
        self.cli = self.entry_data.client

        # Cleared on removal so a pending TTS stream does not start
        self._is_running: bool = True
        self._pipeline_task: asyncio.Task | None = None
        # Incoming audio chunks; None requests the pipeline to stop
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._tts_streaming_task: asyncio.Task | None = None
        self._udp_server: VoiceAssistantUDPServer | None = None

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass."""
        await super().async_added_to_hass()

        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if feature_flags & VoiceAssistantFeature.API_AUDIO:
            # TCP audio
            self.entry_data.disconnect_callbacks.add(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                    handle_audio=self.handle_audio,
                )
            )
        else:
            # UDP audio
            self.entry_data.disconnect_callbacks.add(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                )
            )

        if feature_flags & VoiceAssistantFeature.TIMERS:
            # Device supports timers
            assert (self.registry_entry is not None) and (
                self.registry_entry.device_id is not None
            )
            self.entry_data.disconnect_callbacks.add(
                async_register_timer_handler(
                    self.hass, self.registry_entry.device_id, self.handle_timer_event
                )
            )

        if feature_flags & VoiceAssistantFeature.ANNOUNCE:
            # Device supports announcements
            self._attr_supported_features |= (
                assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
            )

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        # Chain to the parent classes, mirroring async_added_to_hass
        await super().async_will_remove_from_hass()

        self._is_running = False
        self._stop_pipeline()

    async def _internal_async_announce(self, media_id: str) -> None:
        """Send an announce request for the media URL to the device.

        NOTE(review): this returns as soon as the client call returns; the
        base class expects this method to block until the announcement has
        finished. Confirm whether send_voice_assistant_announce does that.
        """
        self.cli.send_voice_assistant_announce(media_id)

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Translate a pipeline event and send it to the device.

        Events without an ESPHome equivalent are logged and dropped. A
        TTS_END event additionally starts streaming the response audio when
        the device has the SPEAKER feature flag.
        """
        try:
            event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
        except KeyError:
            _LOGGER.debug("Received unknown pipeline event type: %s", event.type)
            return

        data_to_send: dict[str, Any] = {}
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
            self.entry_data.async_set_assist_pipeline_state(True)
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert event.data is not None
            data_to_send = {"text": event.data["stt_output"]["text"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
            assert event.data is not None
            data_to_send = {
                "conversation_id": event.data["intent_output"]["conversation_id"] or "",
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert event.data is not None
            data_to_send = {"text": event.data["tts_input"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert event.data is not None
            tts_output = event.data["tts_output"]
            if tts_output:
                path = tts_output["url"]
                url = async_process_play_media_url(self.hass, path)
                data_to_send = {"url": url}

                assert self.entry_data.device_info is not None
                feature_flags = (
                    self.entry_data.device_info.voice_assistant_feature_flags_compat(
                        self.entry_data.api_version
                    )
                )
                if feature_flags & VoiceAssistantFeature.SPEAKER:
                    media_id = tts_output["media_id"]
                    self._tts_streaming_task = (
                        self.config_entry.async_create_background_task(
                            self.hass,
                            self._stream_tts_audio(media_id),
                            "esphome_voice_assistant_tts",
                        )
                    )
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
            assert event.data is not None
            if not event.data["wake_word_output"]:
                # Report a missing wake word to the device as an error
                event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
                data_to_send = {
                    "code": "no_wake_word",
                    "message": "No wake word detected",
                }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            assert event.data is not None
            data_to_send = {
                "code": event.data["code"],
                "message": event.data["message"],
            }

        self.cli.send_voice_assistant_event(event_type, data_to_send)

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int | None:
        """Handle pipeline run request.

        Starts the pipeline as a background task and returns the port of the
        UDP server started for this run, or 0 when no UDP server was started.
        conversation_id and audio_settings are accepted but not used here.
        """
        # Clear audio queue (entries left over from a previous run)
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

        if self._tts_streaming_task is not None:
            # Cancel current TTS response
            self._tts_streaming_task.cancel()
            self._tts_streaming_task = None

        # API or UDP output audio
        port: int = 0
        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
            feature_flags & VoiceAssistantFeature.API_AUDIO
        ):
            port = await self._start_udp_server()
            _LOGGER.debug("Started UDP server on port %s", port)

        # Get entity ids for pipeline and finished speaking detection
        ent_reg = er.async_get(self.hass)
        pipeline_entity_id = ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-pipeline",
        )
        vad_sensitivity_entity_id = ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
        )

        # Device triggered pipeline (wake word, etc.)
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT

        end_stage = PipelineStage.TTS

        # Run the pipeline
        _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
        self.entry_data.async_set_assist_pipeline_state(True)
        self._pipeline_task = self.config_entry.async_create_background_task(
            self.hass,
            self._async_accept_pipeline_from_satellite(
                audio_stream=self._wrap_audio_stream(),
                start_stage=start_stage,
                end_stage=end_stage,
                pipeline_entity_id=pipeline_entity_id,
                vad_sensitivity_entity_id=vad_sensitivity_entity_id,
                wake_word_phrase=wake_word_phrase,
            ),
            "esphome_assist_satellite_pipeline",
        )
        self._pipeline_task.add_done_callback(
            lambda _future: self.handle_pipeline_finished()
        )

        return port

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API."""
        self._audio_queue.put_nowait(data)

    async def handle_pipeline_stop(self) -> None:
        """Handle request for pipeline to stop."""
        self._stop_pipeline()

    def handle_pipeline_finished(self) -> None:
        """Handle when pipeline has finished running."""
        self.entry_data.async_set_assist_pipeline_state(False)
        self._stop_udp_server()
        _LOGGER.debug("Pipeline finished")

    def handle_timer_event(
        self, event_type: TimerEventType, timer_info: TimerInfo
    ) -> None:
        """Translate a timer event and send it to the device."""
        try:
            native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
        except KeyError:
            _LOGGER.debug("Received unknown timer event type: %s", event_type)
            return

        self.cli.send_voice_assistant_timer_event(
            native_event_type,
            timer_info.id,
            timer_info.name,
            timer_info.created_seconds,
            timer_info.seconds_left,
            timer_info.is_active,
        )

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 512,
    ) -> None:
        """Stream TTS audio chunks to device via API or UDP.

        Sends TTS_STREAM_START before and TTS_STREAM_END after streaming.
        Raises ValueError if the media is not WAV audio.

        NOTE(review): tts_response_finished is only reached when streaming
        runs to completion; the early returns (entity removed, unsupported
        WAV format) and the ValueError skip it just like cancellation does.
        Confirm that leaving the state unchanged is intended in those cases.
        """
        self.cli.send_voice_assistant_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
        )

        try:
            if not self._is_running:
                return

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension != "wav":
                raise ValueError(f"Only WAV audio can be streamed, got {extension}")

            with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
                if (
                    (wav_file.getframerate() != sample_rate)
                    or (wav_file.getsampwidth() != sample_width)
                    or (wav_file.getnchannels() != sample_channels)
                ):
                    _LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
                    return

                _LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())

                while True:
                    chunk = wav_file.readframes(samples_per_chunk)
                    if not chunk:
                        break

                    if self._udp_server is not None:
                        self._udp_server.send_audio_bytes(chunk)
                    else:
                        self.cli.send_voice_assistant_audio(chunk)

                    # Wait for 90% of the duration of the audio that was
                    # sent for it to be played. This will overrun the
                    # device's buffer for very long audio, so using a media
                    # player is preferred.
                    samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                    seconds_in_chunk = samples_in_chunk / sample_rate
                    await asyncio.sleep(seconds_in_chunk * 0.9)
        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.cli.send_voice_assistant_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
            )

        # State change
        self.tts_response_finished()

    async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
        """Yield audio chunks from the queue until None or an empty chunk."""
        while True:
            chunk = await self._audio_queue.get()
            if not chunk:
                break

            yield chunk

    def _stop_pipeline(self) -> None:
        """Request pipeline to be stopped."""
        # None terminates _wrap_audio_stream
        self._audio_queue.put_nowait(None)
        _LOGGER.debug("Requested pipeline stop")

    async def _start_udp_server(self) -> int:
        """Start a UDP server on a random free port and return the port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setblocking(False)
            sock.bind(("", 0))  # random free port

            (
                _transport,
                protocol,
            ) = await asyncio.get_running_loop().create_datagram_endpoint(
                partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
            )
        except BaseException:
            # Don't leak the socket when the endpoint could not be created
            sock.close()
            raise

        assert isinstance(protocol, VoiceAssistantUDPServer)
        self._udp_server = protocol

        # Return port
        return cast(int, sock.getsockname()[1])

    def _stop_udp_server(self) -> None:
        """Stop the UDP server if it's running."""
        if self._udp_server is None:
            return

        try:
            self._udp_server.close()
        finally:
            self._udp_server = None

        _LOGGER.debug("Stopped UDP server")
# -----------------------------------------------------------------------------
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the audio queue."""
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__(
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
) -> None:
"""Initialize protocol."""
super().__init__(*args, **kwargs)
self._audio_queue = audio_queue
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if self.remote_addr is None:
self.remote_addr = addr
self._audio_queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
# Stop pipeline
self._audio_queue.put_nowait(None)
def close(self) -> None:
"""Close the receiver."""
if self.transport is not None:
self.transport.close()
self.remote_addr = None
def send_audio_bytes(self, data: bytes) -> None:
"""Send bytes to the device via UDP."""
if self.transport is None:
_LOGGER.error("No transport to send audio to")
return
if self.remote_addr is None:
_LOGGER.error("No address to send audio to")
return
self.transport.sendto(data, self.remote_addr)

View file

@ -27,12 +27,12 @@ from awesomeversion import AwesomeVersion
import voluptuous as vol
from homeassistant.components import tag, zeroconf
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.const import (
ATTR_DEVICE_ID,
CONF_MODE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_LOGGING_CHANGED,
Platform,
)
from homeassistant.core import (
Event,
@ -77,7 +77,6 @@ from .voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantPipeline,
VoiceAssistantUDPPipeline,
handle_timer_event,
)
_LOGGER = logging.getLogger(__name__)
@ -500,29 +499,14 @@ class ESPHomeManager:
)
)
flags = device_info.voice_assistant_feature_flags_compat(api_version)
if flags:
if flags & VoiceAssistantFeature.API_AUDIO:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
handle_audio=self._handle_audio,
)
)
else:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
)
)
if flags & VoiceAssistantFeature.TIMERS:
entry_data.disconnect_callbacks.add(
async_register_timer_handler(
hass, self.device_id, partial(handle_timer_event, cli)
)
)
if device_info.voice_assistant_feature_flags_compat(api_version) and (
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
):
# Create assist satellite entity
await self.hass.config_entries.async_forward_entry_setups(
self.entry, [Platform.ASSIST_SATELLITE]
)
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
cli.subscribe_states(entry_data.async_update_state)
cli.subscribe_service_calls(self.async_on_service_call)
@ -844,4 +828,5 @@ async def cleanup_instance(
cleanup_callback()
await data.async_cleanup()
await data.client.disconnect()
return data

View file

@ -59,6 +59,17 @@
}
},
"entity": {
"assist_satellite": {
"assist_satellite": {
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
}
}
},
"binary_sensor": {
"assist_in_progress": {
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"

View file

@ -20,6 +20,7 @@ from .devices import VoIPDevices
from .voip import HassVoipDatagramProtocol
PLATFORMS = (
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,

View file

@ -0,0 +1,306 @@
"""Assist satellite entity for VoIP integration."""
from __future__ import annotations
import asyncio
from enum import IntFlag
from functools import partial
import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Final
import wave
from voip_utils import RtpDatagramProtocol
from homeassistant.components import tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.async_ import queue_to_iterable
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity
if TYPE_CHECKING:
from . import DomainData
_LOGGER = logging.getLogger(__name__)
_PIPELINE_TIMEOUT_SEC: Final = 30
class Tones(IntFlag):
    """Bit flags selecting which feedback tones are played to the caller."""

    LISTENING = 1 << 0
    PROCESSING = 1 << 1
    ERROR = 1 << 2
# Raw PCM file for each tone; _load_pcm() resolves these names relative to
# this module's directory.
_TONE_FILENAMES: dict[Tones, str] = {
Tones.LISTENING: "tone.pcm",
Tones.PROCESSING: "processing.pcm",
Tones.ERROR: "error.pcm",
}
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Create a satellite entity for every known and future VoIP device."""
    domain_data: DomainData = hass.data[DOMAIN]

    def _satellite_for(device: VoIPDevice) -> VoipAssistSatellite:
        """Build the satellite entity for a single device."""
        return VoipAssistSatellite(hass, device, config_entry)

    @callback
    def async_add_device(device: VoIPDevice) -> None:
        """Add an entity for a device that shows up after setup."""
        async_add_entities([_satellite_for(device)])

    # Register the listener first, then add entities for devices already known.
    domain_data.devices.async_add_new_device_listener(async_add_device)

    known: list[VoIPEntity] = [
        _satellite_for(device) for device in domain_data.devices
    ]
    async_add_entities(known)
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
    """Assist satellite for VoIP devices.

    The entity doubles as the RTP protocol of its VoIP device: incoming audio
    chunks are queued and fed into an Assist pipeline, and the TTS response
    (plus optional feedback tones) is sent back to the caller over RTP.
    """

    entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
    _attr_translation_key = "assist_satellite"
    _attr_has_entity_name = True
    _attr_name = None
    _attr_state = AssistSatelliteState.LISTENING_WAKE_WORD

    def __init__(
        self,
        hass: HomeAssistant,
        voip_device: VoIPDevice,
        config_entry: ConfigEntry,
        tones: Tones = Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
    ) -> None:
        """Initialize an Assist satellite.

        `tones` is a bit mask of the feedback tones that may be played;
        by default all of them are enabled.
        """
        VoIPEntity.__init__(self, voip_device)
        AssistSatelliteEntity.__init__(self)
        RtpDatagramProtocol.__init__(self)

        self.config_entry = config_entry

        self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        self._audio_chunk_timeout: float = 2.0
        self._pipeline_task: asyncio.Task | None = None
        self._pipeline_had_error: bool = False
        self._tts_done = asyncio.Event()
        self._tts_extra_timeout: float = 1.0
        self._tone_bytes: dict[Tones, bytes] = {}
        self._tones = tones
        self._processing_tone_done = asyncio.Event()

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass."""
        # Let the parent entity classes run their own setup first.
        await super().async_added_to_hass()
        self.voip_device.protocol = self

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        await super().async_will_remove_from_hass()
        assert self.voip_device.protocol == self
        self.voip_device.protocol = None

    # -------------------------------------------------------------------------
    # VoIP
    # -------------------------------------------------------------------------

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Handle raw audio chunk.

        Starts a pipeline run if none is active, then queues the chunk.
        """
        if self._pipeline_task is None:
            self._clear_audio_queue()

            # Run pipeline until voice command finishes, then start over
            self._pipeline_task = self.config_entry.async_create_background_task(
                self.hass,
                self._run_pipeline(),
                "voip_pipeline_run",
            )

        self._audio_queue.put_nowait(audio_bytes)

    async def _run_pipeline(
        self,
    ) -> None:
        """Forward audio to pipeline STT and handle TTS.

        The device is marked active for the duration of the run. Whatever
        happens (including a failure while playing the listening tone), the
        device is marked inactive again and a new run is allowed afterwards.
        """
        self.async_set_context(Context(user_id=self.config_entry.data["user"]))
        self.voip_device.set_is_active(True)

        try:
            # Play listening tone at the start of each cycle
            await self._play_tone(Tones.LISTENING, silence_before=0.2)

            try:
                self._tts_done.clear()

                # Run pipeline with a timeout
                _LOGGER.debug("Starting pipeline")
                async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
                    await self._async_accept_pipeline_from_satellite(  # noqa: SLF001
                        audio_stream=queue_to_iterable(
                            self._audio_queue, timeout=self._audio_chunk_timeout
                        ),
                        pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
                            self.hass
                        ),
                        vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
                            self.hass
                        ),
                    )

                    if self._pipeline_had_error:
                        self._pipeline_had_error = False
                        await self._play_tone(Tones.ERROR)
                    else:
                        # Block until TTS is done speaking.
                        #
                        # This is set in _send_tts and has a timeout that's based on the
                        # length of the TTS audio.
                        await self._tts_done.wait()

                _LOGGER.debug("Pipeline finished")
            except PipelineNotFound:
                _LOGGER.warning("Pipeline not found")
            except (asyncio.CancelledError, TimeoutError) as err:
                # Expected after caller hangs up
                _LOGGER.debug("Pipeline cancelled or timed out")
                self.disconnect()
                self._clear_audio_queue()

                if isinstance(err, asyncio.CancelledError):
                    # Cancellation must propagate so the task ends as cancelled.
                    raise
        finally:
            self.voip_device.set_is_active(False)

            # Allow pipeline to run again
            self._pipeline_task = None

    def _clear_audio_queue(self) -> None:
        """Ensure audio queue is empty."""
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """React to pipeline events by playing tones or sending TTS audio."""
        if event.type == PipelineEventType.STT_END:
            if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
                # _send_tts waits on this event so TTS never overlaps the tone.
                self._processing_tone_done.clear()
                self.config_entry.async_create_background_task(
                    self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
                )
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            if event.data and (tts_output := event.data["tts_output"]):
                media_id = tts_output["media_id"]
                self.config_entry.async_create_background_task(
                    self.hass,
                    self._send_tts(media_id),
                    "voip_pipeline_tts",
                )
            else:
                # Empty TTS response
                self._tts_done.set()
        elif event.type == PipelineEventType.ERROR:
            # Play error tone instead of wait for TTS when pipeline is finished.
            self._pipeline_had_error = True

    async def _send_tts(self, media_id: str) -> None:
        """Send TTS audio to caller via RTP.

        Raises ValueError if the audio is not WAV in the expected
        rate/width/channels, and re-raises TimeoutError if sending takes
        longer than the audio's duration plus the extra timeout. In every
        case the pipeline is signalled that TTS is done.
        """
        try:
            if self.transport is None:
                return  # not connected

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension != "wav":
                raise ValueError(f"Only WAV audio can be streamed, got {extension}")

            if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
                # Don't overlap TTS and processing beep
                await self._processing_tone_done.wait()

            with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
                sample_rate = wav_file.getframerate()
                sample_width = wav_file.getsampwidth()
                sample_channels = wav_file.getnchannels()

                if (
                    (sample_rate != RATE)
                    or (sample_width != WIDTH)
                    or (sample_channels != CHANNELS)
                ):
                    raise ValueError(
                        f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                        f" got {sample_rate}/{sample_width}/{sample_channels}"
                    )

                audio_bytes = wav_file.readframes(wav_file.getnframes())

            _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

            # Time out 1 second after TTS audio should be finished
            tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
            tts_seconds = tts_samples / RATE

            async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
                # TTS audio is 16Khz 16-bit mono
                await self._async_send_audio(audio_bytes)
        except TimeoutError:
            _LOGGER.warning("TTS timeout")
            raise
        finally:
            # Signal pipeline to restart
            self._tts_done.set()

            # Update satellite state
            self.tts_response_finished()

    async def _async_send_audio(self, audio_bytes: bytes, **kwargs) -> None:
        """Send audio in executor."""
        await self.hass.async_add_executor_job(
            partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
        )

    async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
        """Play a tone as feedback to the user if it's enabled."""
        if (self._tones & tone) != tone:
            return  # not enabled

        try:
            if tone not in self._tone_bytes:
                # Do I/O in executor
                self._tone_bytes[tone] = await self.hass.async_add_executor_job(
                    self._load_pcm,
                    _TONE_FILENAMES[tone],
                )

            await self._async_send_audio(
                self._tone_bytes[tone],
                silence_before=silence_before,
            )
        finally:
            if tone == Tones.PROCESSING:
                # Release _send_tts even if the tone failed; otherwise it
                # would wait forever and the pipeline would never restart.
                self._processing_tone_done.set()

    def _load_pcm(self, file_name: str) -> bytes:
        """Load raw audio (16Khz, 16-bit mono)."""
        return (Path(__file__).parent / file_name).read_bytes()

View file

@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
self.async_on_remove(
self.voip_device.async_listen_update(self._is_active_changed)
)
@callback
def _is_active_changed(self, device: VoIPDevice) -> None:
"""Call when active state changed."""
self._attr_is_on = self._device.is_active
self._attr_is_on = self.voip_device.is_active
self.async_write_ha_state()

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from voip_utils import CallInfo
from voip_utils import CallInfo, VoipDatagramProtocol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Event, HomeAssistant, callback
@ -22,6 +22,7 @@ class VoIPDevice:
device_id: str
is_active: bool = False
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
protocol: VoipDatagramProtocol | None = None
@callback
def set_is_active(self, active: bool) -> None:
@ -56,6 +57,18 @@ class VoIPDevice:
return False
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
    """Look up the entity id of this device's pipeline select entity."""
    unique_id = f"{self.voip_id}-pipeline"
    return er.async_get(hass).async_get_entity_id("select", DOMAIN, unique_id)
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
    """Look up the entity id of this device's VAD sensitivity select entity."""
    unique_id = f"{self.voip_id}-vad_sensitivity"
    return er.async_get(hass).async_get_entity_id("select", DOMAIN, unique_id)
class VoIPDevices:
"""Class to store devices."""

View file

@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
_attr_has_entity_name = True
_attr_should_poll = False
def __init__(self, device: VoIPDevice) -> None:
def __init__(self, voip_device: VoIPDevice) -> None:
"""Initialize VoIP entity."""
self._device = device
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
self.voip_device = voip_device
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, device.voip_id)},
identifiers={(DOMAIN, voip_device.voip_id)},
)

View file

@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline"],
"dependencies": ["assist_pipeline", "assist_satellite"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",

View file

@ -10,6 +10,17 @@
}
},
"entity": {
"assist_satellite": {
"assist_satellite": {
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
}
}
},
"binary_sensor": {
"call_in_progress": {
"name": "Call in progress"

View file

@ -3,15 +3,11 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial
import io
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING
import wave
from voip_utils import (
CallInfo,
@ -21,33 +17,19 @@ from voip_utils import (
VoipDatagramProtocol,
)
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import (
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_get_pipeline,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.audio_enhancer import (
AudioEnhancer,
MicroVadEnhancer,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid_now
from homeassistant.core import HomeAssistant
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
from .devices import VoIPDevices
_LOGGER = logging.getLogger(__name__)
@ -60,11 +42,8 @@ def make_protocol(
) -> VoipDatagramProtocol:
"""Plays a pre-recorded message if pipeline is misconfigured."""
voip_device = devices.async_get_or_create(call_info)
pipeline_id = pipeline_select.get_chosen_pipeline(
hass,
DOMAIN,
voip_device.voip_id,
)
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
@ -83,22 +62,18 @@ def make_protocol(
rtcp_state=rtcp_state,
)
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
if (protocol := voip_device.protocol) is None:
raise ValueError("VoIP satellite not found")
# Pipeline is properly configured
return PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
rtcp_state=rtcp_state,
)
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol.rtcp_state = rtcp_state
if protocol.rtcp_state is not None:
# Automatically disconnect when BYE is received over RTCP
protocol.rtcp_state.bye_callback = protocol.disconnect
return protocol
class HassVoipDatagramProtocol(VoipDatagramProtocol):
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
await self._closed_event.wait()
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
"""Run a voice assistant pipeline in a loop for a VoIP call."""
def __init__(
self,
hass: HomeAssistant,
language: str,
voip_device: VoIPDevice,
context: Context,
opus_payload_type: int,
pipeline_timeout: float = 30.0,
audio_timeout: float = 2.0,
buffered_chunks_before_speech: int = 100,
listening_tone_enabled: bool = True,
processing_tone_enabled: bool = True,
error_tone_enabled: bool = True,
tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0,
silence_seconds: float = 1.0,
rtcp_state: RtcpState | None = None,
) -> None:
"""Set up pipeline RTP server."""
super().__init__(
rate=RATE,
width=WIDTH,
channels=CHANNELS,
opus_payload_type=opus_payload_type,
rtcp_state=rtcp_state,
)
self.hass = hass
self.language = language
self.voip_device = voip_device
self.pipeline: Pipeline | None = None
self.pipeline_timeout = pipeline_timeout
self.audio_timeout = audio_timeout
self.buffered_chunks_before_speech = buffered_chunks_before_speech
self.listening_tone_enabled = listening_tone_enabled
self.processing_tone_enabled = processing_tone_enabled
self.error_tone_enabled = error_tone_enabled
self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout
self.silence_seconds = silence_seconds
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context
self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
self._tts_done = asyncio.Event()
self._session_id: str | None = None
self._tone_bytes: bytes | None = None
self._processing_bytes: bytes | None = None
self._error_bytes: bytes | None = None
self._pipeline_error: bool = False
def connection_made(self, transport):
"""Server is ready."""
super().connection_made(transport)
self.voip_device.set_is_active(True)
def connection_lost(self, exc):
"""Handle connection is lost or closed."""
super().connection_lost(exc)
self.voip_device.set_is_active(False)
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
self._clear_audio_queue()
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.hass.async_create_background_task(
self._run_pipeline(),
"voip_pipeline_run",
)
self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
self,
) -> None:
"""Forward audio to pipeline STT and handle TTS."""
if self._session_id is None:
self._session_id = ulid_now()
# Play listening tone at the start of each cycle
if self.listening_tone_enabled:
await self._play_listening_tone()
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
audio_enhancer = MicroVadEnhancer(0, 0, True)
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)
speech_detected = await self._wait_for_speech(
segmenter,
audio_enhancer,
chunk_buffer,
)
if not speech_detected:
_LOGGER.debug("No speech detected")
return
_LOGGER.debug("Starting pipeline")
self._tts_done.clear()
async def stt_stream():
try:
async for chunk in self._segment_audio(
segmenter,
audio_enhancer,
chunk_buffer,
):
yield chunk
if self.processing_tone_enabled:
await self._play_processing_tone()
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Audio timeout")
self._session_id = None
self.disconnect()
finally:
self._clear_audio_queue()
# Run pipeline with a timeout
async with asyncio.timeout(self.pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
tts_audio_output="wav",
)
if self._pipeline_error:
self._pipeline_error = False
if self.error_tone_enabled:
await self._play_error_tone()
else:
# Block until TTS is done speaking.
#
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Pipeline timeout")
self._session_id = None
self.disconnect()
finally:
# Allow pipeline to run again
self._pipeline_task = None
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: MutableSequence[bytes],
):
"""Buffer audio chunks until speech is detected.
Returns True if speech was detected, False otherwise.
"""
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
chunk_buffer.append(chunk)
segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
)
if segmenter.in_command:
# Buffer until command starts
if len(vad_buffer) > 0:
chunk_buffer.append(vad_buffer.bytes())
return True
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
return False
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished."""
# Buffered chunks first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
if not segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
):
# Voice command is finished
break
yield chunk
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
def _clear_audio_queue(self) -> None:
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def _event_callback(self, event: PipelineEvent):
if not event.data:
return
if event.type == PipelineEventType.INTENT_END:
# Capture conversation id
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
tts_output = event.data["tts_output"]
if tts_output:
media_id = tts_output["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id),
"voip_pipeline_tts",
)
else:
# Empty TTS response
self._tts_done.set()
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS
self._pipeline_error = True
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
return
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != RATE)
or (sample_width != WIDTH)
or (sample_channels != CHANNELS)
):
raise ValueError(
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
f" got {sample_rate}/{sample_width}/{sample_channels}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Time out 1 second after TTS audio should be finished
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
tts_seconds = tts_samples / RATE
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
# TTS audio is 16Khz 16-bit mono
await self._async_send_audio(audio_bytes)
except TimeoutError:
_LOGGER.warning("TTS timeout")
raise
finally:
# Signal pipeline to restart
self._tts_done.set()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
"""Send audio in executor."""
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
)
async def _play_listening_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is listening."""
if self._tone_bytes is None:
# Do I/O in executor
self._tone_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"tone.pcm",
)
await self._async_send_audio(
self._tone_bytes,
silence_before=self.tone_delay,
)
async def _play_processing_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is processing the voice command."""
if self._processing_bytes is None:
# Do I/O in executor
self._processing_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"processing.pcm",
)
await self._async_send_audio(self._processing_bytes)
async def _play_error_tone(self) -> None:
"""Play a tone to indicate a pipeline error occurred."""
if self._error_bytes is None:
# Do I/O in executor
self._error_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"error.pcm",
)
await self._async_send_audio(self._error_bytes)
def _load_pcm(self, file_name: str) -> bytes:
"""Load raw audio (16Khz, 16-bit mono)."""
return (Path(__file__).parent / file_name).read_bytes()
class PreRecordMessageProtocol(RtpDatagramProtocol):
"""Plays a pre-recorded message on a loop."""

View file

@ -41,6 +41,7 @@ class Platform(StrEnum):
AIR_QUALITY = "air_quality"
ALARM_CONTROL_PANEL = "alarm_control_panel"
ASSIST_SATELLITE = "assist_satellite"
BINARY_SENSOR = "binary_sensor"
BUTTON = "button"
CALENDAR = "calendar"

View file

@ -5,22 +5,28 @@ from __future__ import annotations
from asyncio import (
AbstractEventLoop,
Future,
Queue,
Semaphore,
Task,
TimerHandle,
gather,
get_running_loop,
timeout as async_timeout,
)
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
import concurrent.futures
import logging
import threading
from typing import Any
from typing_extensions import TypeVar
_LOGGER = logging.getLogger(__name__)
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
_DataT = TypeVar("_DataT", default=Any)
def create_eager_task[_T](
coro: Coroutine[Any, Any, _T],
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
"""Return a list of scheduled TimerHandles."""
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
return handles
async def queue_to_iterable(
queue: Queue[_DataT], timeout: float | None = None
) -> AsyncIterable[_DataT]:
"""Stream items from a queue until None with an optional timeout per item."""
if timeout is None:
while (item := await queue.get()) is not None:
yield item
else:
async with async_timeout(timeout):
item = await queue.get()
while item is not None:
yield item
async with async_timeout(timeout):
item = await queue.get()

View file

@ -0,0 +1 @@
"""Tests for Assist Satellite."""

View file

@ -0,0 +1,106 @@
"""Test helpers for Assist Satellite."""
from unittest.mock import Mock
import pytest
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
TEST_DOMAIN = "test_satellite"
class MockAssistSatellite(AssistSatelliteEntity):
    """Assist satellite test double that records the pipeline events it gets."""

    _attr_name = "Test Entity"
    _attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE

    def __init__(self) -> None:
        """Start with an empty event log."""
        self.events: list[PipelineEvent] = []

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Append the event to the log so tests can inspect it."""
        self.events.append(event)
@pytest.fixture
def entity() -> MockAssistSatellite:
    """Provide a fresh mock satellite entity."""
    satellite = MockAssistSatellite()
    return satellite
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
    """Provide a mock config entry for the test domain, registered with hass."""
    mock_entry = MockConfigEntry(domain=TEST_DOMAIN)
    mock_entry.add_to_hass(hass)
    return mock_entry
@pytest.fixture
async def init_components(
    hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
) -> ConfigEntry:
    """Set up a mock integration that provides `entity` as its satellite.

    Returns the config entry after it has been set up.
    """
    assert await async_setup_component(hass, "homeassistant", {})

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up test config entry."""
        await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload test config entry."""
        await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
        return True

    mock_integration(
        hass,
        MockModule(
            TEST_DOMAIN,
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )

    mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())

    async def async_setup_entry_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test assist satellite platform via config entry."""
        async_add_entities([entity])

    loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
    mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)

    with mock_config_flow(TEST_DOMAIN, ConfigFlow):
        assert await hass.config_entries.async_setup(config_entry.entry_id)
        await hass.async_block_till_done()

    return config_entry

View file

@ -0,0 +1,88 @@
"""Test the Assist Satellite entity."""
from unittest.mock import patch
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
vad,
)
from homeassistant.components.assist_satellite import AssistSatelliteState
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant
from .conftest import MockAssistSatellite
ENTITY_ID = "assist_satellite.test_entity"
async def test_entity_state(
    hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
    """Test that the entity state follows pipeline events."""

    def current_state() -> str:
        """Return the satellite's current state, failing if it does not exist."""
        satellite_state = hass.states.get(ENTITY_ID)
        assert satellite_state is not None
        return satellite_state.state

    assert current_state() == AssistSatelliteState.LISTENING_WAKE_WORD

    context = Context()
    audio_stream = object()

    entity.async_set_context(context)

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
    ) as mock_start_pipeline:
        await entity._async_accept_pipeline_from_satellite(audio_stream)  # type: ignore[arg-type]

    # Inspect the keyword arguments the entity handed to the pipeline
    assert mock_start_pipeline.called
    call_kwargs = mock_start_pipeline.call_args.kwargs
    event_callback = call_kwargs["event_callback"]

    assert call_kwargs["context"] is context
    assert event_callback == entity._internal_on_pipeline_event
    assert call_kwargs["stt_metadata"] == stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    assert call_kwargs["stt_stream"] is audio_stream
    assert call_kwargs["pipeline_id"] is None
    assert call_kwargs["device_id"] is None
    assert call_kwargs["tts_audio_output"] == "wav"
    assert call_kwargs["wake_word_phrase"] is None
    assert call_kwargs["audio_settings"] == AudioSettings(
        silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
    )
    assert call_kwargs["start_stage"] == PipelineStage.STT
    assert call_kwargs["end_stage"] == PipelineStage.TTS

    # Events are fired in this order; each maps to the state expected afterwards
    expected_state_after = {
        PipelineEventType.RUN_START: AssistSatelliteState.LISTENING_WAKE_WORD,
        PipelineEventType.WAKE_WORD_START: AssistSatelliteState.LISTENING_WAKE_WORD,
        PipelineEventType.WAKE_WORD_END: AssistSatelliteState.LISTENING_WAKE_WORD,
        PipelineEventType.STT_START: AssistSatelliteState.LISTENING_COMMAND,
        PipelineEventType.STT_VAD_START: AssistSatelliteState.LISTENING_COMMAND,
        PipelineEventType.STT_VAD_END: AssistSatelliteState.LISTENING_COMMAND,
        PipelineEventType.STT_END: AssistSatelliteState.LISTENING_COMMAND,
        PipelineEventType.INTENT_START: AssistSatelliteState.PROCESSING,
        PipelineEventType.INTENT_END: AssistSatelliteState.PROCESSING,
        PipelineEventType.TTS_START: AssistSatelliteState.RESPONDING,
        PipelineEventType.TTS_END: AssistSatelliteState.RESPONDING,
        PipelineEventType.ERROR: AssistSatelliteState.RESPONDING,
        PipelineEventType.RUN_END: AssistSatelliteState.RESPONDING,
    }
    for event_type, expected_state in expected_state_after.items():
        event_callback(PipelineEvent(event_type, {}))
        assert current_state() == expected_state, event_type

    # Once the response has finished the entity listens for the wake word again
    entity.tts_response_finished()
    assert current_state() == AssistSatelliteState.LISTENING_WAKE_WORD

View file

@ -0,0 +1,181 @@
"""Test the Assist Satellite websocket API."""
import asyncio
from collections.abc import AsyncIterable
from unittest.mock import ANY, patch
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntityFeature
from homeassistant.components.media_source import PlayMedia
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from .conftest import MockAssistSatellite
from tests.typing import WebSocketGenerator
ENTITY_ID = "assist_satellite.test_entity"
async def audio_stream() -> AsyncIterable[bytes]:
"""Empty audio stream."""
yield b""
async def test_intercept_wake_word(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that assist_satellite/intercept_wake_word returns the wake word phrase."""
    ws_client = await hass_ws_client(hass)
    intercept_request = {
        "type": "assist_satellite/intercept_wake_word",
        "entity_id": ENTITY_ID,
    }

    with (
        # Stub out pipeline validation and the per-stage preparation steps
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineInput.validate",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_recognize_intent",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_text_to_speech",
            return_value=None,
        ),
        patch.object(entity, "on_pipeline_event") as pipeline_event_spy,
    ):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(intercept_request)

            # Poll until the entity reports that interception is active
            while not entity.is_intercepting_wake_word:
                await asyncio.sleep(0.01)

            # Run a pipeline that carries a wake word phrase
            await entity._async_accept_pipeline_from_satellite(
                audio_stream=audio_stream(),
                start_stage=PipelineStage.STT,
                end_stage=PipelineStage.TTS,
                wake_word_phrase="test wake word",
            )

            # The websocket command answers with the intercepted phrase
            reply = await ws_client.receive_json()
            assert reply["success"]
            assert reply["result"] == {"wake_word_phrase": "test wake word"}

            # The entity received exactly one pipeline event: the run end
            pipeline_event_spy.assert_called_once_with(
                PipelineEvent(PipelineEventType.RUN_END, data=None, timestamp=ANY)
            )
async def test_announce_not_supported(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that assist_satellite/announce fails when the entity has no features."""
    ws_client = await hass_ws_client(hass)
    announce_request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "media_id": "test media id",
    }
    no_features = AssistSatelliteEntityFeature(0)

    # Clear the entity's supported features for the duration of the request
    with patch.object(entity, "_attr_supported_features", no_features):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(announce_request)
            reply = await ws_client.receive_json()

        assert not reply["success"]
        assert reply["error"]["code"] == ERR_NOT_SUPPORTED
async def test_announce_media_id(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that assist_satellite/announce forwards a given media id unchanged."""
    ws_client = await hass_ws_client(hass)
    announce_request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "media_id": "test media id",
    }

    with patch.object(entity, "_internal_async_announce") as announce_spy:
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(announce_request)
            reply = await ws_client.receive_json()

        assert reply["success"]

        # The entity must have been asked to announce exactly the requested media id
        announce_spy.assert_called_once_with("test media id")
async def test_announce_text(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that assist_satellite/announce turns text into a media id."""
    ws_client = await hass_ws_client(hass)
    announce_request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "text": "test text",
    }

    with (
        # Stub the text -> media source id -> resolved media -> URL chain
        patch(
            "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
            return_value="",
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.media_source.async_resolve_media",
            return_value=PlayMedia(url="test media id", mime_type=""),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_process_play_media_url",
            return_value="test media id",
        ),
        patch.object(entity, "_internal_async_announce") as announce_spy,
    ):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(announce_request)
            reply = await ws_client.receive_json()

        assert reply["success"]

        # The entity must have been asked to announce the stubbed media id
        announce_spy.assert_called_once_with("test media id")

File diff suppressed because one or more lines are too long

View file

@ -3,15 +3,26 @@
import asyncio
import io
from pathlib import Path
import time
from unittest.mock import AsyncMock, Mock, patch
import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from voip_utils import CallInfo
from homeassistant.components import assist_pipeline, voip
from homeassistant.components.voip.devices import VoIPDevice
from homeassistant.components import assist_pipeline, assist_satellite, voip
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteState,
)
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
from homeassistant.const import STATE_OFF, STATE_ON, Platform
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import async_setup_component
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
return wav_io.getvalue()
def async_get_satellite_entity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> AssistSatelliteEntity | None:
    """Look up the Assist satellite entity for a unique id prefix.

    The entity registry is searched for ``<unique_id_prefix>-assist_satellite``
    in ``domain``; None is returned when no entity id is registered for it.
    """
    unique_id = f"{unique_id_prefix}-assist_satellite"
    entity_id = er.async_get(hass).async_get_entity_id(
        Platform.ASSIST_SATELLITE, domain, unique_id
    )
    if entity_id is not None:
        satellite_component: EntityComponent[AssistSatelliteEntity] = hass.data[
            assist_satellite.DOMAIN
        ]
        return satellite_component.get_entity(entity_id)

    return None
async def test_is_valid_call(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
) -> None:
    """Test that a call is only valid once the device's allow-call switch is on."""
    assert await async_setup_component(hass, "voip", {})

    protocol = HassVoipDatagramProtocol(hass, voip_devices)
    assert not protocol.is_valid_call(call_info)

    # The allow-call switch exists for the device and starts out off
    switch_unique_id = f"{voip_device.voip_id}-allow_call"
    switch_entity_id = er.async_get(hass).async_get_entity_id(
        "switch", voip.DOMAIN, switch_unique_id
    )
    assert switch_entity_id is not None
    switch_state = hass.states.get(switch_entity_id)
    assert switch_state is not None
    assert switch_state.state == STATE_OFF

    # Turning the switch on makes the same call valid
    hass.states.async_set(switch_entity_id, STATE_ON)
    assert protocol.is_valid_call(call_info)
async def test_calls_not_allowed(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the problem message is played back when calls aren't allowed."""
    assert await async_setup_component(hass, "voip", {})

    protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"

    # Capture whatever audio the protocol tries to send
    playback_finished = asyncio.Event()
    message_audio = b""

    def send_audio(audio_bytes: bytes, **kwargs):
        nonlocal message_audio

        # Should be problem.pcm from components/voip
        message_audio = audio_bytes
        playback_finished.set()

    protocol.transport = Mock()
    protocol.loop_delay = 0

    with patch.object(protocol, "send_audio", send_audio):
        protocol.on_chunk(bytes(_ONE_SECOND))

        async with asyncio.timeout(1):
            await playback_finished.wait()

    # The message must contain non-zero samples and match the stored snapshot
    assert sum(message_audio) > 0
    assert message_audio == snapshot()
async def test_pipeline_not_found(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the problem message protocol is used when no pipeline is found."""
    assert await async_setup_component(hass, "voip", {})

    # Make the pipeline lookup come back empty
    missing_pipeline = patch(
        "homeassistant.components.voip.voip.async_get_pipeline", return_value=None
    )
    with missing_pipeline:
        protocol: PreRecordMessageProtocol = make_protocol(
            hass, voip_devices, call_info
        )

    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"
async def test_satellite_prepared(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the satellite entity is returned as the protocol for a call."""
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    # Pipeline returned by the patched lookup below
    fake_pipeline = assist_pipeline.Pipeline(
        conversation_engine="test",
        conversation_language="en",
        language="en",
        name="test",
        stt_engine="test",
        stt_language="en",
        tts_engine="test",
        tts_language="en",
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )

    with patch(
        "homeassistant.components.voip.voip.async_get_pipeline",
        return_value=fake_pipeline,
    ):
        protocol = make_protocol(hass, voip_devices, call_info)

    assert protocol == satellite
async def test_pipeline(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
) -> None:
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
voip_user_id = satellite.config_entry.data["user"]
assert voip_user_id
return 0
# Satellite is muted until a call begins
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
done = asyncio.Event()
# Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
async def async_pipeline_from_audio_stream(
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
):
assert context.user_id == voip_user_id
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
in_command = False
async for chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
assert _chunk != bad_chunk
assert chunk != bad_chunk
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Test empty data
event_callback(
@ -71,6 +229,38 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_START,
data={"engine": "test", "metadata": {}},
)
)
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_START,
data={
"engine": "test",
"language": hass.config.language,
"intent_input": "fake-text",
"conversation_id": None,
"device_id": None,
},
)
)
assert satellite.state == AssistSatelliteState.PROCESSING
# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
@ -83,6 +273,21 @@ async def test_pipeline(
)
)
# Fake tts result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_START,
data={
"engine": "test",
"language": hass.config.language,
"voice": "test",
"tts_input": "fake-text",
},
)
)
assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
@ -91,6 +296,18 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.RUN_END
)
)
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
@ -100,102 +317,56 @@ async def test_pipeline(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
)
rtp_protocol.transport = Mock()
satellite._tones = Tones(0)
satellite.transport = Mock()
satellite.connection_made(satellite.transport)
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
# Ensure audio queue is cleared before pipeline starts
rtp_protocol._audio_queue.put_nowait(bad_chunk)
satellite._audio_queue.put_nowait(bad_chunk)
def send_audio(*args, **kwargs):
# Test finished successfully
done.set()
# Don't send audio
pass
rtp_protocol.send_audio = Mock(side_effect=send_audio)
satellite.send_audio = Mock(side_effect=send_audio)
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
"""Test timeout during pipeline run."""
assert await async_setup_component(hass, "voip", {})
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
await asyncio.sleep(10)
with (
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
return_value=True,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
pipeline_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
await done.wait()
# Finished speaking
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
async def test_stt_stream_timeout(
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
) -> None:
"""Test timeout in STT stream during pipeline run."""
assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
pass
with patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
audio_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
satellite._tones = Tones(0)
satellite._audio_chunk_timeout = 0.001
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
satellite.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
async def test_tts_timeout(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -278,15 +448,7 @@ async def test_tts_timeout(
tone_bytes = bytes([1, 2, 3, 4])
def send_audio(audio_bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
# Block here to force a timeout in _send_tts
time.sleep(2)
async def async_send_audio(audio_bytes, **kwargs):
async def async_send_audio(audio_bytes: bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
@ -303,37 +465,22 @@ async def test_tts_timeout(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
tts_extra_timeout=0.001,
listening_tone_enabled=True,
processing_tone_enabled=True,
error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
)
rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes
rtp_protocol._error_bytes = tone_bytes
rtp_protocol.transport = Mock()
rtp_protocol.send_audio = Mock()
satellite._tts_extra_timeout = 0.001
for tone in Tones:
satellite._tone_bytes[tone] = tone_bytes
original_send_tts = rtp_protocol._send_tts
satellite.transport = Mock()
satellite.send_audio = Mock()
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -342,17 +489,17 @@ async def test_tts_timeout(
done.set()
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -361,26 +508,34 @@ async def test_tts_timeout(
async def test_tts_wrong_extension(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
async def test_empty_tts_output(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will not stream when output is empty."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -605,37 +754,78 @@ async def test_empty_tts_output(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to finish
async with asyncio.timeout(1):
await rtp_protocol._tts_done.wait()
await satellite._tts_done.wait()
mock_send_tts.assert_not_called()
async def test_pipeline_error(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that audio is sent to the caller after a pipeline error event."""
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    tone_sent = asyncio.Event()
    error_tone_audio = b""

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        # Report an error right away instead of running a real pipeline
        error_event = assist_pipeline.PipelineEvent(
            type=assist_pipeline.PipelineEventType.ERROR,
            data={"code": "error-code", "message": "error message"},
        )
        kwargs["event_callback"](error_event)

    async def async_send_audio(audio_bytes: bytes, **kwargs):
        nonlocal error_tone_audio

        # Should be error.pcm from components/voip
        error_tone_audio = audio_bytes
        tone_sent.set()

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        satellite._tones = Tones.ERROR
        satellite.transport = Mock()
        satellite._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]

        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait until audio has been handed to the send hook
        async with asyncio.timeout(1):
            await tone_sent.wait()

        # The audio must contain non-zero samples and match the stored snapshot
        assert sum(error_tone_audio) > 0
        assert error_tone_audio == snapshot()

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -3,8 +3,8 @@
from unittest.mock import Mock, patch
from homeassistant.components import assist_pipeline
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.pipeline import PipelineData
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
timer_handle.cancel()
timer_handle2.cancel()
timer_handle3.cancel()
async def test_queue_to_iterable() -> None:
    """Exercise queue_to_iterable for normal draining and for timeouts."""
    item_queue: asyncio.Queue[int | None] = asyncio.Queue()
    expected = list(range(10))

    for value in expected:
        await item_queue.put(value)

    # A trailing None ends the stream
    await item_queue.put(None)

    received = []
    async for value in hasync.queue_to_iterable(item_queue):
        received.append(value)

    assert expected == received

    # Timeout case 1: nothing is queued, so waiting for the first item fails
    assert item_queue.empty()

    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # The 0.01s timeout should fire long before the outer 1s guard
            async for _value in hasync.queue_to_iterable(item_queue, timeout=0.01):
                await asyncio.sleep(1)

    # Timeout case 2: one item is queued, so waiting for the second item fails
    assert item_queue.empty()
    only_item = 12345
    await item_queue.put(only_item)

    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # The 0.01s timeout should fire long before the outer 1s guard
            async for value in hasync.queue_to_iterable(item_queue, timeout=0.01):
                if value != only_item:
                    await asyncio.sleep(1)

    assert item_queue.empty()