Compare commits

...
Sign in to create a new pull request.

6 commits

Author SHA1 Message Date
Paulus Schoutsen
c15e02eee3 Allow intercepting wake words for assist satellite (#124990)
* Allow intercepting wake words for assist satellite

* Add one more test

* Update homeassistant/components/assist_satellite/entity.py

Co-authored-by: Michael Hansen <mike@rhasspy.org>

* Finish test coverage

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
2024-09-04 13:36:58 -05:00
Michael Hansen
dd3cd65bfc Add ESPHome assist satellite entity (#124949)
* Add ESPHome assist satellite entity

* Implement feedback

* Apply suggestions from code review

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
2024-09-04 13:36:58 -05:00
Paulus Schoutsen
ee0c649687 Add announce service to Assist Satellite (#124927)
* Add announce service to assist_satellite

* Add tests

* Update docstring

* Update services.yaml

* pylint on tests now yooooo

* Stub out TTS
2024-09-04 13:36:58 -05:00
Paulus Schoutsen
4e27d8ec78 Minor refactor assist satellite (#124912)
* Extract entity ID for pipeline to property

* Add super calls

* Mock tts
2024-09-04 13:36:58 -05:00
Michael Hansen
a1db430249 Incorporate assist satellite entity feedback (#124727)
* Incorporate feedback

* Raise value error

* Clean up entity description

* More cleanup

* Move some things around

* Add a basic test

* Whatever

* Update CODEOWNERS

* Add tests

* Test tts response finished

* Fix test

* Wrong place

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
2024-09-04 13:36:58 -05:00
Michael Hansen
033bc1bbe5 Add Assist satellite entity + VoIP (#123830)
* Add assist_satellite and implement VoIP

* Fix tests

* More tests

* Improve test

* Update entity state

* Set state correctly

* Move more functionality into base class

* Move RTP protocol into entity

* Fix tests

* Remove string

* Move to util method

* Align states better with pipeline events

* Remove public async_get_satellite_entity

* WAITING_FOR_WAKE_WORD

* Pass entity ids for pipeline/vad sensitivity

* Remove connect/disconnect

* Clean up

* Final cleanup
2024-09-04 13:36:56 -05:00
42 changed files with 3190 additions and 2361 deletions

View file

@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
/tests/components/aseko_pool_live/ @milanmeu
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/tests/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View file

@ -17,6 +17,7 @@ from .const import (
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
OPTION_PREFERRED,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
@ -58,6 +59,7 @@ __all__ = (
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
"OPTION_PREFERRED",
"SAMPLES_PER_CHUNK",
"SAMPLE_RATE",
"SAMPLE_WIDTH",

View file

@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
MS_PER_CHUNK = 10
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"

View file

@ -504,7 +504,7 @@ class AudioSettings:
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
silence_seconds: float = 0.5
silence_seconds: float = 0.7
"""Seconds of silence after voice command has ended."""
def __post_init__(self) -> None:
@ -906,6 +906,8 @@ class PipelineRun:
metadata,
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
)
except (asyncio.CancelledError, TimeoutError):
raise # expected
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(

View file

@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(

View file

@ -0,0 +1,65 @@
"""Base class for assist satellite entities."""
import logging
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
from .websocket_api import async_register_websocket_api
# Names re-exported as the public API of the assist_satellite package.
__all__ = [
"DOMAIN",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature",
"AssistSatelliteState",
]
# Module logger; also passed to the EntityComponent created in async_setup.
_LOGGER = logging.getLogger(__name__)
# Alias of the base platform schema from the config validation helpers.
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the Assist satellite entity component.

    Creates the shared EntityComponent (stored in ``hass.data[DOMAIN]``),
    registers the ``announce`` entity service for entities that support the
    ANNOUNCE feature, and registers the websocket API.
    """
    satellites = EntityComponent[AssistSatelliteEntity](_LOGGER, DOMAIN, hass)
    hass.data[DOMAIN] = satellites
    await satellites.async_setup(config)

    # The service accepts a message, a media id, or both - but never neither.
    announce_fields = {
        vol.Optional("message"): str,
        vol.Optional("media_id"): str,
    }
    announce_schema = vol.All(
        cv.make_entity_service_schema(announce_fields),
        cv.has_at_least_one_key("message", "media_id"),
    )
    satellites.async_register_entity_service(
        "announce",
        announce_schema,
        "async_internal_announce",
        [AssistSatelliteEntityFeature.ANNOUNCE],
    )

    async_register_websocket_api(hass)
    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward a config entry to the shared Assist satellite component."""
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    was_set_up = await satellites.async_setup_entry(entry)
    return was_set_up
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry from the shared Assist satellite component."""
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    was_unloaded = await satellites.async_unload_entry(entry)
    return was_unloaded

View file

@ -0,0 +1,3 @@
"""Constants for assist satellite."""
# Integration domain name for the Assist satellite entity component.
DOMAIN = "assist_satellite"

View file

@ -0,0 +1,301 @@
"""Assist satellite entity."""
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
import logging
import time
from typing import Any, Final
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .errors import AssistSatelliteError, SatelliteBusyError
from .models import AssistSatelliteState
# Seconds since the last accepted pipeline run after which the stored
# conversation id is discarded and a fresh one is generated.
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
# Module-level logger.
_LOGGER = logging.getLogger(__name__)
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
    """A class that describes Assist satellite entities."""
class AssistSatelliteEntity(entity.Entity):
    """Entity encapsulating the state and functionality of an Assist satellite.

    Subclasses must implement on_pipeline_event and, when they support
    announcements, async_announce. They call
    async_accept_pipeline_from_satellite to run an Assist pipeline on audio
    coming from the device.
    """

    entity_description: AssistSatelliteEntityDescription
    _attr_should_poll = False
    _attr_state: AssistSatelliteState | None = None
    _attr_pipeline_entity_id: str | None = None
    _attr_vad_sensitivity_entity_id: str | None = None

    # Conversation id reused across pipeline runs until it times out.
    _conversation_id: str | None = None
    _conversation_id_time: float | None = None

    # True once the current pipeline run has reached the TTS stage.
    _run_has_tts: bool = False
    # True while async_announce is in progress.
    _is_announcing: bool = False
    # Set while a caller is waiting to intercept the next wake word.
    _wake_word_intercept_future: asyncio.Future[str | None] | None = None

    @property
    def pipeline_entity_id(self) -> str | None:
        """Entity ID of the pipeline to use for the next conversation."""
        return self._attr_pipeline_entity_id

    @property
    def vad_sensitivity_entity_id(self) -> str | None:
        """Entity ID of the VAD sensitivity to use for the next conversation."""
        return self._attr_vad_sensitivity_entity_id

    async def async_intercept_wake_word(self) -> str | None:
        """Intercept the next wake word from the satellite.

        Returns the detected wake word phrase or None.

        Raises SatelliteBusyError if an interception is already in progress,
        and AssistSatelliteError if the next pipeline run cannot provide a
        wake word phrase.
        """
        if self._wake_word_intercept_future is not None:
            raise SatelliteBusyError("Wake word interception already in progress")

        # Will cause next wake word to be intercepted in
        # async_accept_pipeline_from_satellite
        self._wake_word_intercept_future = asyncio.Future()

        _LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)

        try:
            return await self._wake_word_intercept_future
        finally:
            self._wake_word_intercept_future = None

    async def async_internal_announce(
        self,
        message: str | None = None,
        media_id: str | None = None,
    ) -> None:
        """Play an announcement on the satellite.

        If media_id is not provided, message is synthesized to
        audio with the selected pipeline.

        Calls async_announce with media id.

        Raises SatelliteBusyError if another announcement is in progress.
        """
        if message is None:
            message = ""

        if not media_id:
            # Synthesize audio and get URL
            pipeline_id = self._resolve_pipeline()
            pipeline = async_get_pipeline(self.hass, pipeline_id)

            tts_options: dict[str, Any] = {}
            if pipeline.tts_voice is not None:
                tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

            media_id = tts_generate_media_source_id(
                self.hass,
                message,
                engine=pipeline.tts_engine,
                language=pipeline.tts_language,
                options=tts_options,
            )

        if media_source.is_media_source_id(media_id):
            media = await media_source.async_resolve_media(
                self.hass,
                media_id,
                None,
            )
            media_id = media.url

        # Resolve to full URL
        media_id = async_process_play_media_url(self.hass, media_id)

        if self._is_announcing:
            raise SatelliteBusyError

        self._is_announcing = True

        try:
            # Block until announcement is finished
            await self.async_announce(message, media_id)
        finally:
            self._is_announcing = False

    async def async_announce(self, message: str, media_id: str) -> None:
        """Announce media on the satellite.

        Should block until the announcement is done playing.
        """
        raise NotImplementedError

    async def async_accept_pipeline_from_satellite(
        self,
        audio_stream: AsyncIterable[bytes],
        start_stage: PipelineStage = PipelineStage.STT,
        end_stage: PipelineStage = PipelineStage.TTS,
        wake_word_phrase: str | None = None,
    ) -> None:
        """Triggers an Assist pipeline in Home Assistant from a satellite.

        If a wake word interception is pending, the run is not started;
        the waiting caller receives the wake word phrase (or an error)
        and a RUN_END event is emitted instead.
        """
        intercept_future = self._wake_word_intercept_future
        if (
            intercept_future is not None
            # A future that was cancelled or already resolved (but not yet
            # cleared by the waiter) must not be completed a second time:
            # set_result/set_exception would raise InvalidStateError.
            and not intercept_future.done()
            and start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT)
        ):
            if start_stage == PipelineStage.WAKE_WORD:
                intercept_future.set_exception(
                    AssistSatelliteError(
                        "Only on-device wake words currently supported"
                    )
                )
                return

            # Intercepting wake word and immediately end pipeline
            _LOGGER.debug(
                "Intercepted wake word: %s (entity_id=%s)",
                wake_word_phrase,
                self.entity_id,
            )

            if wake_word_phrase is None:
                intercept_future.set_exception(
                    AssistSatelliteError("No wake word phrase provided")
                )
            else:
                intercept_future.set_result(wake_word_phrase)
            self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
            return

        device_id = self.registry_entry.device_id if self.registry_entry else None

        # Refresh context if necessary
        if (
            (self._context is None)
            or (self._context_set is None)
            or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
        ):
            self.async_set_context(Context())

        assert self._context is not None

        # Reset conversation id if necessary
        if (self._conversation_id_time is None) or (
            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
        ):
            self._conversation_id = None

        if self._conversation_id is None:
            self._conversation_id = ulid.ulid()

        # Update timeout
        self._conversation_id_time = time.monotonic()

        # Set entity state based on pipeline events
        self._run_has_tts = False

        await async_pipeline_from_audio_stream(
            self.hass,
            context=self._context,
            event_callback=self._internal_on_pipeline_event,
            stt_metadata=stt.SpeechMetadata(
                language="",  # set in async_pipeline_from_audio_stream
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            ),
            stt_stream=audio_stream,
            pipeline_id=self._resolve_pipeline(),
            conversation_id=self._conversation_id,
            device_id=device_id,
            tts_audio_output="wav",
            wake_word_phrase=wake_word_phrase,
            audio_settings=AudioSettings(
                silence_seconds=self._resolve_vad_sensitivity()
            ),
            start_stage=start_stage,
            end_stage=end_stage,
        )

    @abstractmethod
    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Handle pipeline events."""

    @callback
    def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage, then forward the event."""
        if event.type is PipelineEventType.WAKE_WORD_START:
            self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
        elif event.type is PipelineEventType.STT_START:
            self._set_state(AssistSatelliteState.LISTENING_COMMAND)
        elif event.type is PipelineEventType.INTENT_START:
            self._set_state(AssistSatelliteState.PROCESSING)
        elif event.type is PipelineEventType.TTS_START:
            # Wait until tts_response_finished is called to return to waiting state
            self._run_has_tts = True
            self._set_state(AssistSatelliteState.RESPONDING)
        elif event.type is PipelineEventType.RUN_END:
            if not self._run_has_tts:
                self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

        self.on_pipeline_event(event)

    @callback
    def _set_state(self, state: AssistSatelliteState) -> None:
        """Set the entity's state and write it to the state machine."""
        self._attr_state = state
        self.async_write_ha_state()

    @callback
    def tts_response_finished(self) -> None:
        """Tell entity that the text-to-speech response has finished playing."""
        self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

    @callback
    def _resolve_pipeline(self) -> str | None:
        """Resolve pipeline from select entity to id.

        Returns None when no pipeline entity is configured, when the
        "preferred" option is selected, or when no pipeline has the
        selected name. Raises RuntimeError if the select entity has no state.
        """
        if not (pipeline_entity_id := self.pipeline_entity_id):
            return None

        if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
            raise RuntimeError("Pipeline entity not found")

        if pipeline_entity_state.state != OPTION_PREFERRED:
            # Resolve pipeline by name
            for pipeline in async_get_pipelines(self.hass):
                if pipeline.name == pipeline_entity_state.state:
                    return pipeline.id

        return None

    @callback
    def _resolve_vad_sensitivity(self) -> float:
        """Resolve VAD sensitivity from select entity to seconds of silence.

        Uses the default sensitivity when no select entity is configured.
        Raises RuntimeError if the select entity has no state.
        """
        vad_sensitivity = vad.VadSensitivity.DEFAULT

        if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
            if (
                vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
            ) is None:
                raise RuntimeError("VAD sensitivity entity not found")

            vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

        return vad.VadSensitivity.to_seconds(vad_sensitivity)

View file

@ -0,0 +1,11 @@
"""Errors for assist satellite."""
from homeassistant.exceptions import HomeAssistantError
class AssistSatelliteError(HomeAssistantError):
    """Root of the exception hierarchy raised by Assist satellites."""
class SatelliteBusyError(AssistSatelliteError):
    """Raised when the satellite is busy and cannot take the request."""

View file

@ -0,0 +1,12 @@
{
"entity_component": {
"_": {
"default": "mdi:account-voice"
}
},
"services": {
"announce": {
"service": "mdi:bullhorn"
}
}
}

View file

@ -0,0 +1,9 @@
{
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}

View file

@ -0,0 +1,26 @@
"""Models for assist satellite."""
from enum import IntFlag, StrEnum
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
LISTENING_WAKE_WORD = "listening_wake_word"
"""Device is streaming audio for wake word detection to Home Assistant."""
LISTENING_COMMAND = "listening_command"
"""Device is streaming audio with the voice command to Home Assistant."""
PROCESSING = "processing"
"""Home Assistant is processing the voice command."""
RESPONDING = "responding"
"""Device is speaking the response."""
class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""
ANNOUNCE = 1
"""Device supports remotely triggered announcements."""

View file

@ -0,0 +1,16 @@
announce:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
fields:
message:
required: false
example: "Time to wake up!"
selector:
text:
media_id:
required: false
selector:
text:

View file

@ -0,0 +1,30 @@
{
"title": "Assist satellite",
"entity_component": {
"_": {
"name": "Assist satellite",
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
},
"services": {
"announce": {
"name": "Announce",
"description": "Let the satellite announce a message.",
"fields": {
"message": {
"name": "Message",
"description": "The message to announce."
},
"media_id": {
"name": "Media ID",
"description": "The media ID to announce instead of using text-to-speech."
}
}
}
}
}

View file

@ -0,0 +1,46 @@
"""Assist satellite Websocket API."""
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from .const import DOMAIN
from .entity import AssistSatelliteEntity
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register the Assist satellite websocket commands with Home Assistant."""
    websocket_api.async_register_command(hass, websocket_intercept_wake_word)
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/intercept_wake_word",
        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_intercept_wake_word(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Wait for a satellite's next wake word and report its phrase.

    Replies with a not-found error when the entity id does not belong to a
    known satellite; otherwise replies with ``wake_word_phrase`` once the
    satellite has intercepted a wake word.
    """
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]

    if (target := satellites.get_entity(msg["entity_id"])) is None:
        connection.send_error(
            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
        )
        return

    phrase = await target.async_intercept_wake_word()
    connection.send_result(msg["id"], {"wake_word_phrase": phrase})

View file

@ -0,0 +1,511 @@
"""Support for assist satellites in ESPHome."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable
from functools import partial
import io
import logging
import socket
from typing import Any, cast
import wave
from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import EsphomeAssistEntity
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
# Module-level logger.
_LOGGER = logging.getLogger(__name__)
# Pairs each ESPHome voice assistant event type with the Home Assistant
# pipeline event type it corresponds to. on_pipeline_event looks entries up
# with from_hass() to translate pipeline events before sending them on.
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
}
)
# Pairs each ESPHome timer event type with the Home Assistant timer event
# type it corresponds to. handle_timer_event looks entries up with from_hass().
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
EsphomeEnumMapper(
{
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
}
)
)
async def async_setup_entry(
    hass: HomeAssistant,
    entry: ESPHomeConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Add an Assist satellite entity for devices with voice assistant support.

    No entity is created when the device reports no voice assistant
    feature flags for its API version.
    """
    runtime_data = entry.runtime_data
    assert runtime_data.device_info is not None

    voice_flags = runtime_data.device_info.voice_assistant_feature_flags_compat(
        runtime_data.api_version
    )
    if not voice_flags:
        return

    async_add_entities([EsphomeAssistSatellite(hass, entry, runtime_data)])
class EsphomeAssistSatellite(
    EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
):
    """Satellite running ESPHome.

    Receives audio from the device (via the API or a per-run UDP server),
    runs it through an Assist pipeline, forwards pipeline events back to
    the device, and optionally streams the TTS response audio to it.
    """

    entity_description = assist_satellite.AssistSatelliteEntityDescription(
        key="assist_satellite",
        translation_key="assist_satellite",
        entity_category=EntityCategory.CONFIG,
    )

    def __init__(
        self,
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        entry_data: RuntimeEntryData,
    ) -> None:
        """Initialize satellite."""
        super().__init__(entry_data)

        self.hass = hass
        self.config_entry = config_entry
        self.entry_data = entry_data
        self.cli = self.entry_data.client

        # Cleared when the entity is removed so audio streaming stops.
        self._is_running: bool = True
        self._pipeline_task: asyncio.Task | None = None
        # Incoming audio chunks; None asks the pipeline audio stream to end.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._tts_streaming_task: asyncio.Task | None = None
        self._udp_server: VoiceAssistantUDPServer | None = None

    @property
    def pipeline_entity_id(self) -> str | None:
        """Return the entity ID of the pipeline to use for the next conversation."""
        assert self.entry_data.device_info is not None
        ent_reg = er.async_get(self.hass)
        return ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-pipeline",
        )

    @property
    def vad_sensitivity_entity_id(self) -> str | None:
        """Return the entity ID of the VAD sensitivity to use for the next conversation."""
        assert self.entry_data.device_info is not None
        ent_reg = er.async_get(self.hass)
        return ent_reg.async_get_entity_id(
            Platform.SELECT,
            DOMAIN,
            f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
        )

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass."""
        await super().async_added_to_hass()

        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if feature_flags & VoiceAssistantFeature.API_AUDIO:
            # TCP audio
            self.entry_data.disconnect_callbacks.add(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                    handle_audio=self.handle_audio,
                )
            )
        else:
            # UDP audio
            self.entry_data.disconnect_callbacks.add(
                self.cli.subscribe_voice_assistant(
                    handle_start=self.handle_pipeline_start,
                    handle_stop=self.handle_pipeline_stop,
                )
            )

        if feature_flags & VoiceAssistantFeature.TIMERS:
            # Device supports timers
            assert (self.registry_entry is not None) and (
                self.registry_entry.device_id is not None
            )
            self.entry_data.disconnect_callbacks.add(
                async_register_timer_handler(
                    self.hass, self.registry_entry.device_id, self.handle_timer_event
                )
            )

        self._set_state(assist_satellite.AssistSatelliteState.LISTENING_WAKE_WORD)

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        await super().async_will_remove_from_hass()

        self._is_running = False
        self._stop_pipeline()

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Translate a pipeline event and send it to the device."""
        try:
            event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
        except KeyError:
            _LOGGER.debug("Received unknown pipeline event type: %s", event.type)
            return

        data_to_send: dict[str, Any] = {}
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
            self.entry_data.async_set_assist_pipeline_state(True)
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert event.data is not None
            data_to_send = {"text": event.data["stt_output"]["text"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
            assert event.data is not None
            data_to_send = {
                "conversation_id": event.data["intent_output"]["conversation_id"] or "",
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert event.data is not None
            data_to_send = {"text": event.data["tts_input"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert event.data is not None
            if tts_output := event.data["tts_output"]:
                path = tts_output["url"]
                url = async_process_play_media_url(self.hass, path)
                data_to_send = {"url": url}

                assert self.entry_data.device_info is not None
                feature_flags = (
                    self.entry_data.device_info.voice_assistant_feature_flags_compat(
                        self.entry_data.api_version
                    )
                )
                if feature_flags & VoiceAssistantFeature.SPEAKER:
                    # Device has a speaker: stream the response audio to it
                    media_id = tts_output["media_id"]
                    self._tts_streaming_task = (
                        self.config_entry.async_create_background_task(
                            self.hass,
                            self._stream_tts_audio(media_id),
                            "esphome_voice_assistant_tts",
                        )
                    )
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
            assert event.data is not None
            if not event.data["wake_word_output"]:
                # Report a missing wake word to the device as an error event
                event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
                data_to_send = {
                    "code": "no_wake_word",
                    "message": "No wake word detected",
                }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            assert event.data is not None
            data_to_send = {
                "code": event.data["code"],
                "message": event.data["message"],
            }

        self.cli.send_voice_assistant_event(event_type, data_to_send)

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int | None:
        """Handle pipeline run request.

        Returns the UDP port the device should send audio to, or 0 when
        no UDP server was started.
        """
        # Clear audio queue
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

        if self._tts_streaming_task is not None:
            # Cancel current TTS response
            self._tts_streaming_task.cancel()
            self._tts_streaming_task = None

        # API or UDP output audio
        port: int = 0
        assert self.entry_data.device_info is not None
        feature_flags = (
            self.entry_data.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
        )
        if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
            feature_flags & VoiceAssistantFeature.API_AUDIO
        ):
            port = await self._start_udp_server()
            _LOGGER.debug("Started UDP server on port %s", port)

        # Device triggered pipeline (wake word, etc.)
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT

        end_stage = PipelineStage.TTS

        # Run the pipeline
        _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
        self.entry_data.async_set_assist_pipeline_state(True)
        self._pipeline_task = self.config_entry.async_create_background_task(
            self.hass,
            self.async_accept_pipeline_from_satellite(
                audio_stream=self._wrap_audio_stream(),
                start_stage=start_stage,
                end_stage=end_stage,
                wake_word_phrase=wake_word_phrase,
            ),
            "esphome_assist_satellite_pipeline",
        )
        self._pipeline_task.add_done_callback(
            lambda _future: self.handle_pipeline_finished()
        )

        return port

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API."""
        self._audio_queue.put_nowait(data)

    async def handle_pipeline_stop(self) -> None:
        """Handle request for pipeline to stop."""
        self._stop_pipeline()

    def handle_pipeline_finished(self) -> None:
        """Handle when pipeline has finished running."""
        self.entry_data.async_set_assist_pipeline_state(False)
        self._stop_udp_server()
        _LOGGER.debug("Pipeline finished")

    def handle_timer_event(
        self, event_type: TimerEventType, timer_info: TimerInfo
    ) -> None:
        """Handle timer events."""
        try:
            native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
        except KeyError:
            _LOGGER.debug("Received unknown timer event type: %s", event_type)
            return

        self.cli.send_voice_assistant_timer_event(
            native_event_type,
            timer_info.id,
            timer_info.name,
            timer_info.created_seconds,
            timer_info.seconds_left,
            timer_info.is_active,
        )

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 512,
    ) -> None:
        """Stream TTS audio chunks to device via API or UDP.

        Sends TTS_STREAM_START before and TTS_STREAM_END after the audio.
        The entity state is reset once streaming ends, including when the
        audio could not be streamed, but not when the task is cancelled or
        the entity is no longer running.
        """
        self.cli.send_voice_assistant_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
        )

        try:
            if not self._is_running:
                return

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension == "wav":
                await self._send_wav_chunks(
                    data,
                    sample_rate,
                    sample_width,
                    sample_channels,
                    samples_per_chunk,
                )
            else:
                # Nothing will be played; fall through so the state is reset
                _LOGGER.error("Only WAV audio can be streamed, got %s", extension)
        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.cli.send_voice_assistant_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
            )

        # State change
        self.tts_response_finished()

    async def _send_wav_chunks(
        self,
        data: bytes,
        sample_rate: int,
        sample_width: int,
        sample_channels: int,
        samples_per_chunk: int,
    ) -> None:
        """Send WAV audio to the device chunk by chunk, paced to playback."""
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            if (
                (wav_file.getframerate() != sample_rate)
                or (wav_file.getsampwidth() != sample_width)
                or (wav_file.getnchannels() != sample_channels)
            ):
                _LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
                return

            _LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())

            while self._is_running:
                chunk = wav_file.readframes(samples_per_chunk)
                if not chunk:
                    break

                if self._udp_server is not None:
                    self._udp_server.send_audio_bytes(chunk)
                else:
                    self.cli.send_voice_assistant_audio(chunk)

                # Wait for 90% of the duration of the audio that was
                # sent for it to be played. This will overrun the
                # device's buffer for very long audio, so using a media
                # player is preferred.
                samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                seconds_in_chunk = samples_in_chunk / sample_rate
                await asyncio.sleep(seconds_in_chunk * 0.9)

    async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
        """Yield audio chunks from the queue until None."""
        while True:
            chunk = await self._audio_queue.get()
            if not chunk:
                break

            yield chunk

    def _stop_pipeline(self) -> None:
        """Request pipeline to be stopped."""
        self._audio_queue.put_nowait(None)
        _LOGGER.debug("Requested pipeline stop")

    async def _start_udp_server(self) -> int:
        """Start a UDP server on a random free port.

        Returns the port number the server is bound to.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setblocking(False)
            sock.bind(("", 0))  # random free port

            (
                _transport,
                protocol,
            ) = await asyncio.get_running_loop().create_datagram_endpoint(
                partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
            )
        except BaseException:
            # Don't leak the socket if binding or endpoint creation fails
            # (or is cancelled); closing an already closed socket is harmless.
            sock.close()
            raise

        assert isinstance(protocol, VoiceAssistantUDPServer)
        self._udp_server = protocol

        # Return port
        return cast(int, sock.getsockname()[1])

    def _stop_udp_server(self) -> None:
        """Stop the UDP server if it's running."""
        if self._udp_server is None:
            return

        try:
            self._udp_server.close()
        finally:
            self._udp_server = None

        _LOGGER.debug("Stopped UDP server")
# -----------------------------------------------------------------------------
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the audio queue."""
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__(
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
) -> None:
"""Initialize protocol."""
super().__init__(*args, **kwargs)
self._audio_queue = audio_queue
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if self.remote_addr is None:
self.remote_addr = addr
self._audio_queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
# Stop pipeline
self._audio_queue.put_nowait(None)
def close(self) -> None:
"""Close the receiver."""
if self.transport is not None:
self.transport.close()
self.remote_addr = None
def send_audio_bytes(self, data: bytes) -> None:
"""Send bytes to the device via UDP."""
if self.transport is None:
_LOGGER.error("No transport to send audio to")
return
if self.remote_addr is None:
_LOGGER.error("No address to send audio to")
return
self.transport.sendto(data, self.remote_addr)

View file

@ -20,19 +20,17 @@ from aioesphomeapi import (
RequiresEncryptionAPIError,
UserService,
UserServiceArgType,
VoiceAssistantAudioSettings,
VoiceAssistantFeature,
)
from awesomeversion import AwesomeVersion
import voluptuous as vol
from homeassistant.components import tag, zeroconf
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.const import (
ATTR_DEVICE_ID,
CONF_MODE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_LOGGING_CHANGED,
Platform,
)
from homeassistant.core import (
Event,
@ -73,12 +71,6 @@ from .domain_data import DomainData
# Import config flow so that it's added to the registry
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
from .voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantPipeline,
VoiceAssistantUDPPipeline,
handle_timer_event,
)
_LOGGER = logging.getLogger(__name__)
@ -149,7 +141,6 @@ class ESPHomeManager:
"cli",
"device_id",
"domain_data",
"voice_assistant_pipeline",
"reconnect_logic",
"zeroconf_instance",
"entry_data",
@ -173,7 +164,6 @@ class ESPHomeManager:
self.cli = cli
self.device_id: str | None = None
self.domain_data = domain_data
self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
self.reconnect_logic: ReconnectLogic | None = None
self.zeroconf_instance = zeroconf_instance
self.entry_data = entry.runtime_data
@ -338,77 +328,6 @@ class ESPHomeManager:
entity_id, attribute, self.hass.states.get(entity_id)
)
def _handle_pipeline_finished(self) -> None:
    """Flag the assist pipeline as inactive and drop the stored pipeline.

    A stored UDP pipeline is closed before the reference is cleared.
    """
    self.entry_data.async_set_assist_pipeline_state(False)

    pipeline = self.voice_assistant_pipeline
    if pipeline is None:
        return

    if isinstance(pipeline, VoiceAssistantUDPPipeline):
        pipeline.close()
    self.voice_assistant_pipeline = None
async def _handle_pipeline_start(
    self,
    conversation_id: str,
    flags: int,
    audio_settings: VoiceAssistantAudioSettings,
    wake_word_phrase: str | None,
) -> int | None:
    """Create a voice assistant pipeline and schedule it as a background task.

    A pipeline left over from a previous run is stopped first. Devices that
    advertise the API_AUDIO feature get an API pipeline and port 0 is
    returned; all others get a UDP pipeline and the port of its freshly
    started server is returned.
    """
    leftover = self.voice_assistant_pipeline
    if leftover is not None:
        _LOGGER.warning("Previous Voice assistant pipeline was not stopped")
        leftover.stop()
        self.voice_assistant_pipeline = None

    hass = self.hass
    assert self.entry_data.device_info is not None
    feature_flags = self.entry_data.device_info.voice_assistant_feature_flags_compat(
        self.entry_data.api_version
    )

    if feature_flags & VoiceAssistantFeature.API_AUDIO:
        self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
            hass,
            self.entry_data,
            self.cli.send_voice_assistant_event,
            self._handle_pipeline_finished,
            self.cli,
        )
        port = 0
    else:
        udp_pipeline = VoiceAssistantUDPPipeline(
            hass,
            self.entry_data,
            self.cli.send_voice_assistant_event,
            self._handle_pipeline_finished,
        )
        self.voice_assistant_pipeline = udp_pipeline
        port = await udp_pipeline.start_server()

    assert self.device_id is not None, "Device ID must be set"
    # The attribute is read again on purpose: it may have been changed
    # while the UDP server start was awaited.
    run_coro = self.voice_assistant_pipeline.run_pipeline(
        device_id=self.device_id,
        conversation_id=conversation_id or None,
        flags=flags,
        audio_settings=audio_settings,
        wake_word_phrase=wake_word_phrase,
    )
    hass.async_create_background_task(
        run_coro,
        "esphome.voice_assistant_pipeline.run_pipeline",
    )

    return port
async def _handle_pipeline_stop(self) -> None:
"""Stop a voice assistant pipeline."""
if self.voice_assistant_pipeline is not None:
self.voice_assistant_pipeline.stop()
async def _handle_audio(self, data: bytes) -> None:
if self.voice_assistant_pipeline is None:
return
assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline)
self.voice_assistant_pipeline.receive_audio_bytes(data)
async def on_connect(self) -> None:
"""Subscribe to states and list entities on successful API login."""
try:
@ -509,29 +428,14 @@ class ESPHomeManager:
)
)
flags = device_info.voice_assistant_feature_flags_compat(api_version)
if flags:
if flags & VoiceAssistantFeature.API_AUDIO:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
handle_audio=self._handle_audio,
)
)
else:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
)
)
if flags & VoiceAssistantFeature.TIMERS:
entry_data.disconnect_callbacks.add(
async_register_timer_handler(
hass, self.device_id, partial(handle_timer_event, cli)
)
)
if device_info.voice_assistant_feature_flags_compat(api_version) and (
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
):
# Create assist satellite entity
await self.hass.config_entries.async_forward_entry_setups(
self.entry, [Platform.ASSIST_SATELLITE]
)
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
cli.subscribe_states(entry_data.async_update_state)
cli.subscribe_service_calls(self.async_on_service_call)

View file

@ -1,479 +0,0 @@
"""ESPHome voice assistant support."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
import io
import logging
import socket
from typing import cast
import wave
from aioesphomeapi import (
APIClient,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
WakeWordSettings,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.error import (
WakeWordDetectionAborted,
WakeWordDetectionError,
)
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback
from .const import DOMAIN
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
# Port handed to bind() when the UDP pipeline starts its server.
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
# NOTE(review): not referenced in the visible part of this module; _send_tts
# slices audio with a literal 1024 - confirm whether this constant was meant
# to be used there.
UDP_MAX_PACKET_SIZE = 1024
# Pairs each ESPHome voice assistant event type with the matching Home
# Assistant pipeline event type. _event_callback uses from_hass() to translate
# an incoming pipeline event; a KeyError there is treated as "unknown event".
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
}
)
# Pairs each ESPHome timer event type with the matching Home Assistant timer
# event type. handle_timer_event uses from_hass() to translate timer events.
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
EsphomeEnumMapper(
{
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
}
)
)
class VoiceAssistantPipeline:
"""Base abstract pipeline class."""
started = False
stop_requested = False
def __init__(
    self,
    hass: HomeAssistant,
    entry_data: RuntimeEntryData,
    handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
    handle_finished: Callable[[], None],
) -> None:
    """Set up the state shared by all pipeline variants.

    ``handle_event`` is called with each translated event and its payload;
    ``handle_finished`` is called when the pipeline run is over. The entry
    data must already carry device info.
    """
    assert entry_data.device_info is not None

    self.hass = hass
    self.context = Context()
    self.entry_data = entry_data
    self.device_info = entry_data.device_info

    # Callbacks supplied by the owner of the pipeline.
    self.handle_event = handle_event
    self.handle_finished = handle_finished

    # Incoming audio, consumed by _iterate_packets.
    self.queue: asyncio.Queue[bytes] = asyncio.Queue()

    # TTS bookkeeping: the event is awaited at the end of run_pipeline.
    self._tts_done = asyncio.Event()
    self._tts_task: asyncio.Task | None = None
@property
def is_running(self) -> bool:
    """Whether the pipeline was started and no stop has been requested."""
    return (not self.stop_requested) if self.started else self.started
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
while data := await self.queue.get():
if not self.is_running:
break
yield data
def _event_callback(self, event: PipelineEvent) -> None:
    """Translate a pipeline event and forward it through ``handle_event``.

    Event types without an ESPHome counterpart are logged and dropped. For
    the others a payload is built from the event data. A wake-word end
    without output is converted to a ``no_wake_word`` error. After an
    error the TTS-done flag is set and ``handle_finished`` is called.
    """
    try:
        native_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
    except KeyError:
        _LOGGER.debug("Received unknown pipeline event type: %s", event.type)
        return

    payload = None
    failed = False
    native = VoiceAssistantEventType

    if native_type == native.VOICE_ASSISTANT_STT_START:
        self.entry_data.async_set_assist_pipeline_state(True)
    elif native_type == native.VOICE_ASSISTANT_STT_END:
        assert event.data is not None
        payload = {"text": event.data["stt_output"]["text"]}
    elif native_type == native.VOICE_ASSISTANT_INTENT_END:
        assert event.data is not None
        payload = {
            "conversation_id": event.data["intent_output"]["conversation_id"] or "",
        }
    elif native_type == native.VOICE_ASSISTANT_TTS_START:
        assert event.data is not None
        payload = {"text": event.data["tts_input"]}
    elif native_type == native.VOICE_ASSISTANT_TTS_END:
        assert event.data is not None
        tts_output = event.data["tts_output"]
        if not tts_output:
            # Empty TTS response
            payload = {}
            self._tts_done.set()
        else:
            media_url = async_process_play_media_url(self.hass, tts_output["url"])
            payload = {"url": media_url}

            device_features = self.device_info.voice_assistant_feature_flags_compat(
                self.entry_data.api_version
            )
            if device_features & VoiceAssistantFeature.SPEAKER:
                # Devices with a speaker get the audio streamed to them.
                self._tts_task = self.hass.async_create_background_task(
                    self._send_tts(tts_output["media_id"]),
                    "esphome_voice_assistant_tts",
                )
            else:
                self._tts_done.set()
    elif native_type == native.VOICE_ASSISTANT_WAKE_WORD_END:
        assert event.data is not None
        if not event.data["wake_word_output"]:
            native_type = native.VOICE_ASSISTANT_ERROR
            payload = {
                "code": "no_wake_word",
                "message": "No wake word detected",
            }
            failed = True
    elif native_type == native.VOICE_ASSISTANT_ERROR:
        assert event.data is not None
        payload = {
            "code": event.data["code"],
            "message": event.data["message"],
        }
        failed = True

    self.handle_event(native_type, payload)

    if failed:
        self._tts_done.set()
        self.handle_finished()
async def run_pipeline(
    self,
    device_id: str,
    conversation_id: str | None,
    flags: int = 0,
    audio_settings: VoiceAssistantAudioSettings | None = None,
    wake_word_phrase: str | None = None,
) -> None:
    """Run one assist pipeline over the queued audio.

    Default audio settings are used when none are given or when the volume
    multiplier is 0. TTS is requested as WAV for devices with the SPEAKER
    feature and as MP3 otherwise. The run starts at the wake-word stage
    when ``USE_WAKE_WORD`` is set in ``flags``, else at speech-to-text.
    After the pipeline returns, this waits for the TTS-done flag.
    ``handle_finished`` is always called at the end.
    """
    if audio_settings is None or audio_settings.volume_multiplier == 0:
        audio_settings = VoiceAssistantAudioSettings()

    device_features = self.device_info.voice_assistant_feature_flags_compat(
        self.entry_data.api_version
    )
    tts_audio_output = (
        "wav" if device_features & VoiceAssistantFeature.SPEAKER else "mp3"
    )

    _LOGGER.debug("Starting pipeline")
    start_stage = (
        PipelineStage.WAKE_WORD
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD
        else PipelineStage.STT
    )

    try:
        stt_metadata = stt.SpeechMetadata(
            language="",  # set in async_pipeline_from_audio_stream
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        )
        stt_stream = self._iterate_packets()
        pipeline_id = pipeline_select.get_chosen_pipeline(
            self.hass, DOMAIN, self.device_info.mac_address
        )
        wake_word_settings = WakeWordSettings(timeout=5)
        pipeline_audio_settings = AudioSettings(
            noise_suppression_level=audio_settings.noise_suppression_level,
            auto_gain_dbfs=audio_settings.auto_gain,
            volume_multiplier=audio_settings.volume_multiplier,
            is_vad_enabled=bool(flags & VoiceAssistantCommandFlag.USE_VAD),
            silence_seconds=VadSensitivity.to_seconds(
                pipeline_select.get_vad_sensitivity(
                    self.hass, DOMAIN, self.device_info.mac_address
                )
            ),
        )

        await async_pipeline_from_audio_stream(
            self.hass,
            context=self.context,
            event_callback=self._event_callback,
            stt_metadata=stt_metadata,
            stt_stream=stt_stream,
            pipeline_id=pipeline_id,
            conversation_id=conversation_id,
            device_id=device_id,
            tts_audio_output=tts_audio_output,
            start_stage=start_stage,
            wake_word_settings=wake_word_settings,
            wake_word_phrase=wake_word_phrase,
            audio_settings=pipeline_audio_settings,
        )

        # Block until TTS is done sending
        await self._tts_done.wait()

        _LOGGER.debug("Pipeline finished")
    except PipelineNotFound as err:
        self.handle_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {
                "code": err.code,
                "message": err.message,
            },
        )
        _LOGGER.warning("Pipeline not found")
    except WakeWordDetectionAborted:
        pass  # Wake word detection was aborted and `handle_finished` is enough.
    except WakeWordDetectionError as err:
        self.handle_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {
                "code": err.code,
                "message": err.message,
            },
        )
    finally:
        self.handle_finished()
async def _send_tts(self, media_id: str) -> None:
    """Stream the TTS audio for ``media_id`` to the device in chunks.

    A TTS_STREAM_START event is always sent first and a TTS_STREAM_END
    event is always sent last, even when nothing is streamed. The audio
    must be 16 kHz, 16-bit, mono WAV.

    Raises:
        ValueError: if the media is not WAV or has an unexpected format.
    """
    # Always send stream start/end events
    self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})

    try:
        if not self.is_running:
            return

        extension, data = await tts.async_get_media_source_audio(
            self.hass,
            media_id,
        )

        if extension != "wav":
            raise ValueError(f"Only WAV audio can be streamed, got {extension}")

        with io.BytesIO(data) as wav_io:
            with wave.open(wav_io, "rb") as wav_file:
                sample_rate = wav_file.getframerate()
                sample_width = wav_file.getsampwidth()
                sample_channels = wav_file.getnchannels()

                if (
                    (sample_rate != 16000)
                    or (sample_width != 2)
                    or (sample_channels != 1)
                ):
                    # The second literal must be an f-string, otherwise the
                    # placeholders are emitted verbatim.
                    raise ValueError(
                        "Expected rate/width/channels as 16000/2/1,"
                        f" got {sample_rate}/{sample_width}/{sample_channels}"
                    )

                audio_bytes = wav_file.readframes(wav_file.getnframes())

        audio_bytes_size = len(audio_bytes)

        _LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)

        bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
        sample_offset = 0
        samples_left = audio_bytes_size // bytes_per_sample

        while (samples_left > 0) and self.is_running:
            bytes_offset = sample_offset * bytes_per_sample
            chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
            samples_in_chunk = len(chunk) // bytes_per_sample
            samples_left -= samples_in_chunk

            self.send_audio_bytes(chunk)
            # Sleep for 90% of the chunk's play time before sending more.
            await asyncio.sleep(
                samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
            )

            sample_offset += samples_in_chunk
    finally:
        self.handle_event(
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
        )
        self._tts_task = None
        # Unblocks run_pipeline, which waits on this event.
        self._tts_done.set()
def send_audio_bytes(self, data: bytes) -> None:
    """Deliver audio bytes to the device.

    The base class has no transport, so this always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def stop(self) -> None:
    """Queue an empty packet, which ends the ``_iterate_packets`` stream."""
    end_marker = b""
    self.queue.put_nowait(end_marker)
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
    """Receive UDP packets and forward them to the voice assistant."""

    # Set once the event loop attaches a transport (see connection_made).
    transport: asyncio.DatagramTransport | None = None
    # Address of the first peer a datagram was received from.
    remote_addr: tuple[str, int] | None = None

    async def start_server(self) -> int:
        """Bind a UDP socket and start receiving on it.

        Returns:
            The port number the socket was bound to.

        Raises:
            RuntimeError: if the pipeline was already started or stopped.
        """

        def accept_connection() -> VoiceAssistantUDPPipeline:
            """Hand this object to the event loop as the protocol, once."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.stop_requested:
                raise RuntimeError("No longer accepting connections")

            self.started = True
            return self

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setblocking(False)
            sock.bind(("", UDP_PORT))

            await asyncio.get_running_loop().create_datagram_endpoint(
                accept_connection, sock=sock
            )
        except BaseException:
            # Without a transport nothing else would close the socket.
            # Closing an already-closed socket is harmless.
            sock.close()
            raise

        return cast(int, sock.getsockname()[1])

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store transport for later use."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Queue an incoming UDP packet while the pipeline is running."""
        if not self.is_running:
            return
        if self.remote_addr is None:
            # Latch the first sender; TTS audio is sent back to it.
            self.remote_addr = addr
        self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Handle when a send or receive operation raises an OSError.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
        self.handle_finished()

    @callback
    def stop(self) -> None:
        """Stop the receiver."""
        super().stop()
        self.close()

    def close(self) -> None:
        """Mark the pipeline as stopped and close the transport, if any."""
        self.started = False
        self.stop_requested = True

        if self.transport is not None:
            self.transport.close()

    def send_audio_bytes(self, data: bytes) -> None:
        """Send bytes to the device via UDP.

        Logs an error and sends nothing when the transport or the remote
        address is not known yet.
        """
        if self.transport is None:
            _LOGGER.error("No transport to send audio to")
            return

        if self.remote_addr is None:
            # No datagram has been received, so there is no peer to address.
            _LOGGER.error("No address to send audio to")
            return

        self.transport.sendto(data, self.remote_addr)
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
    """Pipeline variant that exchanges audio with the device over the API client."""

    def __init__(
        self,
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
        api_client: APIClient,
    ) -> None:
        """Set up the base pipeline and mark it as started right away."""
        super().__init__(hass, entry_data, handle_event, handle_finished)
        self.api_client = api_client
        self.started = True

    def send_audio_bytes(self, data: bytes) -> None:
        """Hand outgoing audio to the API client."""
        self.api_client.send_voice_assistant_audio(data)

    @callback
    def receive_audio_bytes(self, data: bytes) -> None:
        """Queue audio received from the device; dropped when not running."""
        if self.is_running:
            self.queue.put_nowait(data)

    @callback
    def stop(self) -> None:
        """Queue the end marker, then flag the pipeline as stopped."""
        super().stop()

        self.started = False
        self.stop_requested = True
def handle_timer_event(
    api_client: APIClient, event_type: TimerEventType, timer_info: TimerInfo
) -> None:
    """Forward a Home Assistant timer event to the device.

    Event types without an ESPHome counterpart are logged and dropped.
    """
    try:
        native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
    except KeyError:
        _LOGGER.debug("Received unknown timer event type: %s", event_type)
    else:
        api_client.send_voice_assistant_timer_event(
            native_event_type,
            timer_info.id,
            timer_info.name,
            timer_info.created_seconds,
            timer_info.seconds_left,
            timer_info.is_active,
        )

View file

@ -20,6 +20,7 @@ from .devices import VoIPDevices
from .voip import HassVoipDatagramProtocol
PLATFORMS = (
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,

View file

@ -0,0 +1,312 @@
"""Assist satellite entity for VoIP integration."""
from __future__ import annotations
import asyncio
from enum import IntFlag
from functools import partial
import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Final
import wave
from voip_utils import RtpDatagramProtocol
from homeassistant.components import tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.async_ import queue_to_iterable
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity
if TYPE_CHECKING:
from . import DomainData
_LOGGER = logging.getLogger(__name__)
_PIPELINE_TIMEOUT_SEC: Final = 30
class Tones(IntFlag):
    """Bit flags selecting which feedback tones are played."""

    LISTENING = 1 << 0
    PROCESSING = 1 << 1
    ERROR = 1 << 2
# File name of the raw audio clip for each tone. _play_tone looks the name up
# here and _load_pcm reads the file from the directory of this module.
_TONE_FILENAMES: dict[Tones, str] = {
Tones.LISTENING: "tone.pcm",
Tones.PROCESSING: "processing.pcm",
Tones.ERROR: "error.pcm",
}
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Create an Assist satellite entity for every VoIP device.

    A listener is registered first so devices that appear later get an
    entity too; then entities for the already known devices are added.
    """
    domain_data: DomainData = hass.data[DOMAIN]

    @callback
    def async_add_device(device: VoIPDevice) -> None:
        """Add the satellite entity for a newly discovered device."""
        async_add_entities([VoipAssistSatellite(hass, device, config_entry)])

    domain_data.devices.async_add_new_device_listener(async_add_device)

    entities: list[VoIPEntity] = []
    for known_device in domain_data.devices:
        entities.append(VoipAssistSatellite(hass, known_device, config_entry))

    async_add_entities(entities)
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
"""Assist satellite for VoIP devices."""
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_has_entity_name = True
_attr_name = None
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
def __init__(
    self,
    hass: HomeAssistant,
    voip_device: VoIPDevice,
    config_entry: ConfigEntry,
    tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
) -> None:
    """Set up the satellite for one VoIP device.

    ``tones`` selects which feedback tones are played; all are enabled by
    default.
    """
    # Each base class is initialised explicitly.
    VoIPEntity.__init__(self, voip_device)
    AssistSatelliteEntity.__init__(self)
    RtpDatagramProtocol.__init__(self)

    self.config_entry = config_entry

    # Incoming audio and the task consuming it.
    self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
    self._audio_chunk_timeout: float = 2.0
    self._pipeline_task: asyncio.Task | None = None
    self._pipeline_had_error: bool = False

    # TTS playback bookkeeping.
    self._tts_done = asyncio.Event()
    self._tts_extra_timeout: float = 1.0

    # Feedback tones: enabled flags, cached audio, processing-tone signal.
    self._tone_bytes: dict[Tones, bytes] = {}
    self._tones = tones
    self._processing_tone_done = asyncio.Event()
@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
return self.voip_device.get_pipeline_entity_id(self.hass)
@property
def vad_sensitivity_entity_id(self) -> str | None:
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
return self.voip_device.get_vad_sensitivity_entity_id(self.hass)
async def async_added_to_hass(self) -> None:
    """Register this entity as the protocol of its VoIP device.

    Runs after the parent classes' own added-to-hass handling.
    """
    await super().async_added_to_hass()
    device = self.voip_device
    device.protocol = self
async def async_will_remove_from_hass(self) -> None:
    """Detach this entity from its VoIP device.

    Runs after the parent classes' own removal handling; the device is
    asserted to still point at this entity.
    """
    await super().async_will_remove_from_hass()
    device = self.voip_device
    assert device.protocol == self
    device.protocol = None
# -------------------------------------------------------------------------
# VoIP
# -------------------------------------------------------------------------
def on_chunk(self, audio_bytes: bytes) -> None:
    """Queue a raw audio chunk, starting the pipeline task when none is set.

    Before a new task is created the queue is emptied, so the new run does
    not see stale audio.
    """
    needs_pipeline = self._pipeline_task is None
    if needs_pipeline:
        self._clear_audio_queue()

        # Run pipeline until voice command finishes, then start over
        create_task = self.config_entry.async_create_background_task
        self._pipeline_task = create_task(
            self.hass,
            self._run_pipeline(),
            "voip_pipeline_run",
        )

    self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
    self,
) -> None:
    """Forward audio to pipeline STT and handle TTS.

    One call covers one voice command: the listening tone is played, the
    pipeline runs under a timeout, and then either the error tone is
    played or this waits until TTS playback is reported done.

    The device is always marked inactive and ``_pipeline_task`` is always
    cleared at the end, including when setup or the listening tone fails,
    so ``on_chunk`` can start a new run.
    """
    try:
        self.async_set_context(Context(user_id=self.config_entry.data["user"]))
        self.voip_device.set_is_active(True)

        # Play listening tone at the start of each cycle
        await self._play_tone(Tones.LISTENING, silence_before=0.2)

        try:
            self._tts_done.clear()

            # Run pipeline with a timeout
            _LOGGER.debug("Starting pipeline")
            async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
                await self.async_accept_pipeline_from_satellite(
                    audio_stream=queue_to_iterable(
                        self._audio_queue, timeout=self._audio_chunk_timeout
                    ),
                )

            if self._pipeline_had_error:
                self._pipeline_had_error = False
                await self._play_tone(Tones.ERROR)
            else:
                # Block until TTS is done speaking.
                #
                # This is set in _send_tts and has a timeout that's based on the
                # length of the TTS audio.
                await self._tts_done.wait()

            _LOGGER.debug("Pipeline finished")
        except PipelineNotFound:
            _LOGGER.warning("Pipeline not found")
        except (asyncio.CancelledError, TimeoutError):
            # Expected after caller hangs up
            _LOGGER.debug("Pipeline cancelled or timed out")
            self.disconnect()
            self._clear_audio_queue()
    finally:
        # Runs for every exit path. Errors raised before the inner try
        # still propagate to the caller exactly as before.
        self.voip_device.set_is_active(False)

        # Allow pipeline to run again
        self._pipeline_task = None
def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty."""
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def on_pipeline_event(self, event: PipelineEvent) -> None:
    """React to pipeline progress with tones and TTS playback.

    * STT end: start the processing tone in the background, if enabled.
    * TTS end: stream the TTS audio in the background, or mark TTS as done
      when the response is empty.
    * Error: remember the failure so the error tone is played later.
    """
    stage = event.type

    if stage == PipelineEventType.STT_END:
        if (self._tones & Tones.PROCESSING) != Tones.PROCESSING:
            return

        # Cleared here, set again by _play_tone once the tone was sent.
        self._processing_tone_done.clear()
        self.config_entry.async_create_background_task(
            self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
        )
        return

    if stage == PipelineEventType.TTS_END:
        tts_output = event.data["tts_output"] if event.data else None
        if not tts_output:
            # Empty TTS response
            self._tts_done.set()
            return

        # Send TTS audio to caller over RTP
        self.config_entry.async_create_background_task(
            self.hass,
            self._send_tts(tts_output["media_id"]),
            "voip_pipeline_tts",
        )
        return

    if stage == PipelineEventType.ERROR:
        # Play error tone instead of wait for TTS when pipeline is finished.
        self._pipeline_had_error = True
async def _send_tts(self, media_id: str) -> None:
    """Play the TTS audio for ``media_id`` to the caller over RTP.

    Nothing is sent when there is no transport. The audio must be WAV in
    the RATE/WIDTH/CHANNELS format. Sending is limited to the audio's own
    duration plus ``_tts_extra_timeout``. TTS is always marked as done and
    the satellite state is always updated at the end.
    """
    try:
        if self.transport is None:
            return  # not connected

        extension, wav_data = await tts.async_get_media_source_audio(
            self.hass,
            media_id,
        )

        if extension != "wav":
            raise ValueError(f"Only WAV audio can be streamed, got {extension}")

        processing_enabled = (self._tones & Tones.PROCESSING) == Tones.PROCESSING
        if processing_enabled:
            # Don't overlap TTS and processing beep
            await self._processing_tone_done.wait()

        with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            sample_width = wav_file.getsampwidth()
            sample_channels = wav_file.getnchannels()

            format_matches = (
                (sample_rate == RATE)
                and (sample_width == WIDTH)
                and (sample_channels == CHANNELS)
            )
            if not format_matches:
                raise ValueError(
                    f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                    f" got {sample_rate}/{sample_width}/{sample_channels}"
                )

            audio_bytes = wav_file.readframes(wav_file.getnframes())

        _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

        # Time out 1 second after TTS audio should be finished
        tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
        tts_seconds = tts_samples / RATE
        send_deadline = tts_seconds + self._tts_extra_timeout

        async with asyncio.timeout(send_deadline):
            # TTS audio is 16Khz 16-bit mono
            await self._async_send_audio(audio_bytes)
    except TimeoutError:
        _LOGGER.warning("TTS timeout")
        raise
    finally:
        # Signal pipeline to restart
        self._tts_done.set()

        # Update satellite state
        self.tts_response_finished()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
    """Run ``send_audio`` in the executor with the RTP audio settings.

    Extra keyword arguments are passed through to ``send_audio``.
    """
    send_job = partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
    await self.hass.async_add_executor_job(send_job)
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
    """Play a tone as feedback to the user if it's enabled.

    The tone's audio is read from disk on first use and cached in
    ``_tone_bytes``. For the processing tone, ``_processing_tone_done`` is
    set when this returns or raises, so ``_send_tts`` (which waits on that
    event) cannot block forever after a failed tone.

    Args:
        tone: The tone to play.
        silence_before: Passed through to the audio sender.
    """
    if (self._tones & tone) != tone:
        return  # not enabled

    try:
        if tone not in self._tone_bytes:
            # Do I/O in executor
            self._tone_bytes[tone] = await self.hass.async_add_executor_job(
                self._load_pcm,
                _TONE_FILENAMES[tone],
            )

        await self._async_send_audio(
            self._tone_bytes[tone],
            silence_before=silence_before,
        )
    finally:
        # Also reached on failure or cancellation; the error still propagates.
        if tone == Tones.PROCESSING:
            self._processing_tone_done.set()
def _load_pcm(self, file_name: str) -> bytes:
"""Load raw audio (16Khz, 16-bit mono)."""
return (Path(__file__).parent / file_name).read_bytes()

View file

@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
self.async_on_remove(
self.voip_device.async_listen_update(self._is_active_changed)
)
@callback
def _is_active_changed(self, device: VoIPDevice) -> None:
"""Call when active state changed."""
self._attr_is_on = self._device.is_active
self._attr_is_on = self.voip_device.is_active
self.async_write_ha_state()

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from voip_utils import CallInfo
from voip_utils import CallInfo, VoipDatagramProtocol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Event, HomeAssistant, callback
@ -22,6 +22,7 @@ class VoIPDevice:
device_id: str
is_active: bool = False
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
protocol: VoipDatagramProtocol | None = None
@callback
def set_is_active(self, active: bool) -> None:
@ -56,6 +57,18 @@ class VoIPDevice:
return False
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
    """Look up the entity ID of this device's pipeline select entity.

    The lookup goes through the entity registry using the unique ID
    ``<voip_id>-pipeline``.
    """
    registry = er.async_get(hass)
    unique_id = f"{self.voip_id}-pipeline"
    return registry.async_get_entity_id("select", DOMAIN, unique_id)
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
    """Look up the entity ID of this device's VAD sensitivity select entity.

    The lookup goes through the entity registry using the unique ID
    ``<voip_id>-vad_sensitivity``.
    """
    registry = er.async_get(hass)
    unique_id = f"{self.voip_id}-vad_sensitivity"
    return registry.async_get_entity_id("select", DOMAIN, unique_id)
class VoIPDevices:
"""Class to store devices."""

View file

@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
_attr_has_entity_name = True
_attr_should_poll = False
def __init__(self, device: VoIPDevice) -> None:
def __init__(self, voip_device: VoIPDevice) -> None:
"""Initialize VoIP entity."""
self._device = device
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
self.voip_device = voip_device
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, device.voip_id)},
identifiers={(DOMAIN, voip_device.voip_id)},
)

View file

@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline"],
"dependencies": ["assist_pipeline", "assist_satellite"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",

View file

@ -10,6 +10,16 @@
}
},
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
}
}
},
"binary_sensor": {
"call_in_progress": {
"name": "Call in progress"

View file

@ -3,15 +3,11 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial
import io
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING
import wave
from voip_utils import (
CallInfo,
@ -21,33 +17,19 @@ from voip_utils import (
VoipDatagramProtocol,
)
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import (
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_get_pipeline,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.audio_enhancer import (
AudioEnhancer,
MicroVadSpeexEnhancer,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid_now
from homeassistant.core import HomeAssistant
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
from .devices import VoIPDevices
_LOGGER = logging.getLogger(__name__)
@ -60,11 +42,8 @@ def make_protocol(
) -> VoipDatagramProtocol:
"""Plays a pre-recorded message if pipeline is misconfigured."""
voip_device = devices.async_get_or_create(call_info)
pipeline_id = pipeline_select.get_chosen_pipeline(
hass,
DOMAIN,
voip_device.voip_id,
)
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
@ -83,22 +62,18 @@ def make_protocol(
rtcp_state=rtcp_state,
)
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
if (protocol := voip_device.protocol) is None:
raise ValueError("VoIP satellite not found")
# Pipeline is properly configured
return PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
rtcp_state=rtcp_state,
)
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol.rtcp_state = rtcp_state
if protocol.rtcp_state is not None:
# Automatically disconnect when BYE is received over RTCP
protocol.rtcp_state.bye_callback = protocol.disconnect
return protocol
class HassVoipDatagramProtocol(VoipDatagramProtocol):
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
await self._closed_event.wait()
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
    """Run a voice assistant pipeline in a loop for a VoIP call.

    Incoming audio chunks are queued by ``on_chunk``. A background task
    (``_run_pipeline``) waits for speech, streams the voice command into the
    assist pipeline, sends the TTS response (or an error tone) back to the
    caller and then clears itself so the next chunk starts a new run.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        language: str,
        voip_device: VoIPDevice,
        context: Context,
        opus_payload_type: int,
        pipeline_timeout: float = 30.0,
        audio_timeout: float = 2.0,
        buffered_chunks_before_speech: int = 100,
        listening_tone_enabled: bool = True,
        processing_tone_enabled: bool = True,
        error_tone_enabled: bool = True,
        tone_delay: float = 0.2,
        tts_extra_timeout: float = 1.0,
        silence_seconds: float = 1.0,
        rtcp_state: RtcpState | None = None,
    ) -> None:
        """Set up pipeline RTP server.

        pipeline_timeout: seconds allowed for one pipeline run, including the
            wait for the TTS response to finish.
        audio_timeout: seconds to wait for the next audio chunk before the
            wait fails with TimeoutError.
        buffered_chunks_before_speech: maximum number of chunks kept while
            waiting for speech to start.
        tone_delay: passed as ``silence_before`` when the listening tone is sent.
        tts_extra_timeout: seconds added to the TTS audio duration when
            computing the send timeout.
        silence_seconds: passed to the voice command segmenter.
        """
        super().__init__(
            rate=RATE,
            width=WIDTH,
            channels=CHANNELS,
            opus_payload_type=opus_payload_type,
            rtcp_state=rtcp_state,
        )

        self.hass = hass
        self.language = language
        self.voip_device = voip_device
        self.pipeline: Pipeline | None = None
        self.pipeline_timeout = pipeline_timeout
        self.audio_timeout = audio_timeout
        self.buffered_chunks_before_speech = buffered_chunks_before_speech
        self.listening_tone_enabled = listening_tone_enabled
        self.processing_tone_enabled = processing_tone_enabled
        self.error_tone_enabled = error_tone_enabled
        self.tone_delay = tone_delay
        self.tts_extra_timeout = tts_extra_timeout
        self.silence_seconds = silence_seconds

        # Raw audio chunks received from the caller, consumed by the pipeline task
        self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        self._context = context
        # Captured from INTENT_END and passed to the next pipeline run
        self._conversation_id: str | None = None
        # None whenever no pipeline run is active
        self._pipeline_task: asyncio.Task | None = None
        # Set once the TTS response was sent (or there was nothing to send)
        self._tts_done = asyncio.Event()
        # Reset to None when an audio/pipeline timeout occurs
        self._session_id: str | None = None
        # Lazily loaded tone audio
        self._tone_bytes: bytes | None = None
        self._processing_bytes: bytes | None = None
        self._error_bytes: bytes | None = None
        # Set by the ERROR pipeline event, consumed after the run
        self._pipeline_error: bool = False

    def connection_made(self, transport) -> None:
        """Server is ready; mark the VoIP device as active."""
        super().connection_made(transport)
        self.voip_device.set_is_active(True)

    def connection_lost(self, exc) -> None:
        """Handle connection is lost or closed; mark the VoIP device as inactive."""
        super().connection_lost(exc)
        self.voip_device.set_is_active(False)

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Queue a raw audio chunk, starting a pipeline run if none is active."""
        if self._pipeline_task is None:
            # Drop audio left over from the previous run
            self._clear_audio_queue()

            # Run pipeline until voice command finishes, then start over
            self._pipeline_task = self.hass.async_create_background_task(
                self._run_pipeline(),
                "voip_pipeline_run",
            )

        self._audio_queue.put_nowait(audio_bytes)

    async def _run_pipeline(
        self,
    ) -> None:
        """Forward audio to pipeline STT and handle TTS.

        Waits for speech, runs the pipeline under ``pipeline_timeout`` and then
        either plays the error tone or waits for the TTS response to finish.
        A TimeoutError resets the session id and disconnects the call.
        """
        if self._session_id is None:
            self._session_id = ulid_now()

            # Play listening tone at the start of each cycle
            # NOTE(review): the tone is nested under the new-session check
            # here - confirm this nesting against the upstream file.
            if self.listening_tone_enabled:
                await self._play_listening_tone()

        try:
            # Wait for speech before starting pipeline
            segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
            audio_enhancer = MicroVadSpeexEnhancer(0, 0, True)
            chunk_buffer: deque[bytes] = deque(
                maxlen=self.buffered_chunks_before_speech,
            )
            speech_detected = await self._wait_for_speech(
                segmenter,
                audio_enhancer,
                chunk_buffer,
            )
            if not speech_detected:
                _LOGGER.debug("No speech detected")
                return

            _LOGGER.debug("Starting pipeline")
            self._tts_done.clear()

            async def stt_stream():
                # Yields the voice command audio, then plays the processing tone
                try:
                    async for chunk in self._segment_audio(
                        segmenter,
                        audio_enhancer,
                        chunk_buffer,
                    ):
                        yield chunk

                    if self.processing_tone_enabled:
                        await self._play_processing_tone()
                except TimeoutError:
                    # Expected after caller hangs up
                    _LOGGER.debug("Audio timeout")
                    self._session_id = None
                    self.disconnect()
                finally:
                    self._clear_audio_queue()

            # Run pipeline with a timeout
            async with asyncio.timeout(self.pipeline_timeout):
                await async_pipeline_from_audio_stream(
                    self.hass,
                    context=self._context,
                    event_callback=self._event_callback,
                    stt_metadata=stt.SpeechMetadata(
                        language="",  # set in async_pipeline_from_audio_stream
                        format=stt.AudioFormats.WAV,
                        codec=stt.AudioCodecs.PCM,
                        bit_rate=stt.AudioBitRates.BITRATE_16,
                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                        channel=stt.AudioChannels.CHANNEL_MONO,
                    ),
                    stt_stream=stt_stream(),
                    pipeline_id=pipeline_select.get_chosen_pipeline(
                        self.hass, DOMAIN, self.voip_device.voip_id
                    ),
                    conversation_id=self._conversation_id,
                    device_id=self.voip_device.device_id,
                    tts_audio_output="wav",
                )

                if self._pipeline_error:
                    # Consume the flag so the next run starts clean
                    self._pipeline_error = False
                    if self.error_tone_enabled:
                        await self._play_error_tone()
                else:
                    # Block until TTS is done speaking.
                    #
                    # This is set in _send_tts and has a timeout that's based on the
                    # length of the TTS audio.
                    await self._tts_done.wait()

            _LOGGER.debug("Pipeline finished")
        except PipelineNotFound:
            _LOGGER.warning("Pipeline not found")
        except TimeoutError:
            # Expected after caller hangs up
            _LOGGER.debug("Pipeline timeout")
            self._session_id = None
            self.disconnect()
        finally:
            # Allow pipeline to run again
            self._pipeline_task = None

    async def _wait_for_speech(
        self,
        segmenter: VoiceCommandSegmenter,
        audio_enhancer: AudioEnhancer,
        chunk_buffer: MutableSequence[bytes],
    ) -> bool:
        """Buffer audio chunks until speech is detected.

        Every received chunk is appended to ``chunk_buffer``.

        Returns True if speech was detected, False otherwise.
        Raises TimeoutError if no chunk arrives within ``audio_timeout``.
        """
        # Timeout if no audio comes in for a while.
        # This means the caller hung up.
        async with asyncio.timeout(self.audio_timeout):
            chunk = await self._audio_queue.get()

        vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)

        # An empty chunk ends the loop without speech
        while chunk:
            chunk_buffer.append(chunk)

            segmenter.process_with_vad(
                chunk,
                assist_pipeline.SAMPLES_PER_CHUNK,
                lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
                vad_buffer,
            )
            if segmenter.in_command:
                # Buffer until command starts
                if len(vad_buffer) > 0:
                    chunk_buffer.append(vad_buffer.bytes())

                return True

            async with asyncio.timeout(self.audio_timeout):
                chunk = await self._audio_queue.get()

        return False

    async def _segment_audio(
        self,
        segmenter: VoiceCommandSegmenter,
        audio_enhancer: AudioEnhancer,
        chunk_buffer: Sequence[bytes],
    ) -> AsyncIterable[bytes]:
        """Yield audio chunks until voice command has finished.

        Raises TimeoutError if no chunk arrives within ``audio_timeout``.
        """
        # Buffered chunks first
        for buffered_chunk in chunk_buffer:
            yield buffered_chunk

        # Timeout if no audio comes in for a while.
        # This means the caller hung up.
        async with asyncio.timeout(self.audio_timeout):
            chunk = await self._audio_queue.get()

        vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)

        while chunk:
            if not segmenter.process_with_vad(
                chunk,
                assist_pipeline.SAMPLES_PER_CHUNK,
                lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
                vad_buffer,
            ):
                # Voice command is finished
                break

            yield chunk

            async with asyncio.timeout(self.audio_timeout):
                chunk = await self._audio_queue.get()

    def _clear_audio_queue(self) -> None:
        # Discard every chunk that is currently queued
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

    def _event_callback(self, event: PipelineEvent) -> None:
        # Pipeline event handler; events without data are ignored
        if not event.data:
            return

        if event.type == PipelineEventType.INTENT_END:
            # Capture conversation id
            self._conversation_id = event.data["intent_output"]["conversation_id"]
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            tts_output = event.data["tts_output"]
            if tts_output:
                media_id = tts_output["media_id"]
                self.hass.async_create_background_task(
                    self._send_tts(media_id),
                    "voip_pipeline_tts",
                )
            else:
                # Empty TTS response
                self._tts_done.set()
        elif event.type == PipelineEventType.ERROR:
            # Play error tone instead of wait for TTS
            self._pipeline_error = True

    async def _send_tts(self, media_id: str) -> None:
        """Send TTS audio to caller via RTP.

        Raises ValueError if the audio is not WAV or does not match
        RATE/WIDTH/CHANNELS, and re-raises TimeoutError if sending takes longer
        than the audio duration plus ``tts_extra_timeout``. ``_tts_done`` is
        set in every case.
        """
        try:
            if self.transport is None:
                return

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension != "wav":
                raise ValueError(f"Only WAV audio can be streamed, got {extension}")

            with io.BytesIO(data) as wav_io:
                with wave.open(wav_io, "rb") as wav_file:
                    sample_rate = wav_file.getframerate()
                    sample_width = wav_file.getsampwidth()
                    sample_channels = wav_file.getnchannels()

                    if (
                        (sample_rate != RATE)
                        or (sample_width != WIDTH)
                        or (sample_channels != CHANNELS)
                    ):
                        raise ValueError(
                            f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                            f" got {sample_rate}/{sample_width}/{sample_channels}"
                        )

                    audio_bytes = wav_file.readframes(wav_file.getnframes())

            _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

            # Time out 1 second after TTS audio should be finished
            tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
            tts_seconds = tts_samples / RATE

            async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
                # TTS audio is 16Khz 16-bit mono
                await self._async_send_audio(audio_bytes)
        except TimeoutError:
            _LOGGER.warning("TTS timeout")
            raise
        finally:
            # Signal pipeline to restart
            self._tts_done.set()

    async def _async_send_audio(self, audio_bytes: bytes, **kwargs) -> None:
        """Send audio in executor.

        Extra keyword arguments are forwarded to ``send_audio`` together with
        RTP_AUDIO_SETTINGS.
        """
        await self.hass.async_add_executor_job(
            partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
        )

    async def _play_listening_tone(self) -> None:
        """Play a tone to indicate that Home Assistant is listening."""
        if self._tone_bytes is None:
            # Do I/O in executor
            self._tone_bytes = await self.hass.async_add_executor_job(
                self._load_pcm,
                "tone.pcm",
            )

        await self._async_send_audio(
            self._tone_bytes,
            silence_before=self.tone_delay,
        )

    async def _play_processing_tone(self) -> None:
        """Play a tone to indicate that the voice command is being processed."""
        if self._processing_bytes is None:
            # Do I/O in executor
            self._processing_bytes = await self.hass.async_add_executor_job(
                self._load_pcm,
                "processing.pcm",
            )

        await self._async_send_audio(self._processing_bytes)

    async def _play_error_tone(self) -> None:
        """Play a tone to indicate a pipeline error occurred."""
        if self._error_bytes is None:
            # Do I/O in executor
            self._error_bytes = await self.hass.async_add_executor_job(
                self._load_pcm,
                "error.pcm",
            )

        await self._async_send_audio(self._error_bytes)

    def _load_pcm(self, file_name: str) -> bytes:
        """Load raw audio (16Khz, 16-bit mono) from a file next to this module."""
        return (Path(__file__).parent / file_name).read_bytes()
class PreRecordMessageProtocol(RtpDatagramProtocol):
"""Plays a pre-recorded message on a loop."""

View file

@ -41,6 +41,7 @@ class Platform(StrEnum):
AIR_QUALITY = "air_quality"
ALARM_CONTROL_PANEL = "alarm_control_panel"
ASSIST_SATELLITE = "assist_satellite"
BINARY_SENSOR = "binary_sensor"
BUTTON = "button"
CALENDAR = "calendar"

View file

@ -5,22 +5,28 @@ from __future__ import annotations
from asyncio import (
AbstractEventLoop,
Future,
Queue,
Semaphore,
Task,
TimerHandle,
gather,
get_running_loop,
timeout as async_timeout,
)
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
import concurrent.futures
import logging
import threading
from typing import Any
from typing_extensions import TypeVar
_LOGGER = logging.getLogger(__name__)
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
_DataT = TypeVar("_DataT", default=Any)
def create_eager_task[_T](
coro: Coroutine[Any, Any, _T],
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
"""Return a list of scheduled TimerHandles."""
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
return handles
async def queue_to_iterable(
queue: Queue[_DataT], timeout: float | None = None
) -> AsyncIterable[_DataT]:
"""Stream items from a queue until None with an optional timeout per item."""
if timeout is None:
while (item := await queue.get()) is not None:
yield item
else:
async with async_timeout(timeout):
item = await queue.get()
while item is not None:
yield item
async with async_timeout(timeout):
item = await queue.get()

View file

@ -0,0 +1,3 @@
"""Tests for Assist Satellite."""
ENTITY_ID = "assist_satellite.test_entity"

View file

@ -0,0 +1,116 @@
"""Test helpers for Assist Satellite."""
from unittest.mock import Mock
import pytest
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
from tests.components.tts.conftest import (
mock_tts_cache_dir_fixture_autouse, # noqa: F401
)
TEST_DOMAIN = "test_satellite"
class MockAssistSatellite(AssistSatelliteEntity):
    """Assist satellite test double that records what it is asked to do."""

    _attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
    _attr_name = "Test Entity"

    def __init__(self) -> None:
        """Start with empty announcement and event logs."""
        # (message, media_id) pairs in the order they were announced
        self.announcements = []
        # Pipeline events in the order they were received
        self.events = []

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Record a pipeline event."""
        self.events.append(event)

    async def async_announce(self, message: str, media_id: str) -> None:
        """Record an announcement instead of playing it."""
        announcement = (message, media_id)
        self.announcements.append(announcement)
@pytest.fixture
def entity() -> MockAssistSatellite:
    """Build a new mock assist satellite entity."""
    satellite = MockAssistSatellite()
    return satellite
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
    """Create a config entry for the test domain and add it to hass."""
    mock_entry = MockConfigEntry(domain=TEST_DOMAIN)
    mock_entry.add_to_hass(hass)
    return mock_entry
@pytest.fixture
async def init_components(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    entity: MockAssistSatellite,
) -> ConfigEntry:
    """Set up a mock integration that exposes ``entity`` as an assist satellite.

    Returns the config entry that was set up, so the annotation is
    ``ConfigEntry`` (tests request this fixture typed as ``ConfigEntry``).
    """
    assert await async_setup_component(hass, "homeassistant", {})

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up test config entry."""
        await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload test config entry."""
        await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
        return True

    # Register the fake integration together with a stub config flow module
    mock_integration(
        hass,
        MockModule(
            TEST_DOMAIN,
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )
    mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())

    async def async_setup_entry_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test satellite platform via config entry."""
        async_add_entities([entity])

    # The assist_satellite platform of the fake integration adds the mock entity
    loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
    mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)

    with mock_config_flow(TEST_DOMAIN, ConfigFlow):
        assert await hass.config_entries.async_setup(config_entry.entry_id)
        await hass.async_block_till_done()

    return config_entry

View file

@ -0,0 +1,149 @@
"""Test the Assist Satellite entity."""
from unittest.mock import patch
import pytest
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import Context, HomeAssistant
from . import ENTITY_ID
from .conftest import MockAssistSatellite
async def test_entity_state(
    hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
    """Check the pipeline start arguments and the state shown after each event."""
    initial_state = hass.states.get(ENTITY_ID)
    assert initial_state is not None
    assert initial_state.state == STATE_UNKNOWN

    run_context = Context()
    fake_audio = object()
    entity.async_set_context(run_context)

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
    ) as mock_start_pipeline:
        await entity.async_accept_pipeline_from_satellite(fake_audio)

    assert mock_start_pipeline.called
    kwargs = mock_start_pipeline.call_args.kwargs

    expected_metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    expected_audio_settings = AudioSettings(
        silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
    )

    assert kwargs["context"] is run_context
    assert kwargs["event_callback"] == entity._internal_on_pipeline_event
    assert kwargs["stt_metadata"] == expected_metadata
    assert kwargs["stt_stream"] is fake_audio
    assert kwargs["pipeline_id"] is None
    assert kwargs["device_id"] is None
    assert kwargs["tts_audio_output"] == "wav"
    assert kwargs["wake_word_phrase"] is None
    assert kwargs["audio_settings"] == expected_audio_settings
    assert kwargs["start_stage"] == PipelineStage.STT
    assert kwargs["end_stage"] == PipelineStage.TTS

    # Events are fired in this order; each row names the state expected
    # right after the event was delivered.
    state_after_event = (
        (PipelineEventType.RUN_START, STATE_UNKNOWN),
        (PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
        (PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
        (PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
        (PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
        (PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
        (PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
        (PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
        (PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
        (PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
        (PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
        (PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
        (PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
    )
    send_event = kwargs["event_callback"]
    for event_type, expected_state in state_after_event:
        send_event(PipelineEvent(event_type, {}))
        current = hass.states.get(ENTITY_ID)
        assert current.state == expected_state, event_type

    entity.tts_response_finished()
    final_state = hass.states.get(ENTITY_ID)
    assert final_state.state == AssistSatelliteState.LISTENING_WAKE_WORD
@pytest.mark.parametrize(
    ("service_data", "expected_params"),
    [
        (
            {"message": "Hello"},
            ("Hello", "https://www.home-assistant.io/resolved.mp3"),
        ),
        (
            {
                "message": "Hello",
                "media_id": "http://example.com/bla.mp3",
            },
            ("Hello", "http://example.com/bla.mp3"),
        ),
        (
            {"media_id": "http://example.com/bla.mp3"},
            ("", "http://example.com/bla.mp3"),
        ),
    ],
)
async def test_announce(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    service_data: dict,
    expected_params: tuple[str, str],
) -> None:
    """Call the announce service and check what reaches the entity."""
    default_pipeline = async_get_pipeline(hass)
    await async_update_pipeline(
        hass,
        default_pipeline,
        tts_engine="tts.mock_entity",
        tts_language="en",
    )

    # TTS generation and media resolution are stubbed out
    resolved_media = PlayMedia(
        url="https://www.home-assistant.io/resolved.mp3",
        mime_type="audio/mp3",
    )
    generate_patch = patch(
        "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
        return_value="media-source://bla",
    )
    resolve_patch = patch(
        "homeassistant.components.media_source.async_resolve_media",
        return_value=resolved_media,
    )

    with generate_patch, resolve_patch:
        await hass.services.async_call(
            "assist_satellite",
            "announce",
            service_data,
            target={"entity_id": "assist_satellite.test_entity"},
            blocking=True,
        )

    first_announcement = entity.announcements[0]
    assert first_announcement == expected_params

View file

@ -0,0 +1,192 @@
"""Test WebSocket API."""
import asyncio
from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import ENTITY_ID
from .conftest import MockAssistSatellite
from tests.common import MockUser
from tests.typing import WebSocketGenerator
async def test_intercept_wake_word(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that an intercepted wake word phrase is returned to the client."""
    client = await hass_ws_client(hass)
    request = {
        "type": "assist_satellite/intercept_wake_word",
        "entity_id": ENTITY_ID,
    }
    await client.send_json_auto_id(request)

    # Yield to the event loop a few times before starting the pipeline
    # (presumably so the websocket command is handled first)
    for _ in range(3):
        await asyncio.sleep(0)

    await entity.async_accept_pipeline_from_satellite(
        object(),
        start_stage=PipelineStage.STT,
        wake_word_phrase="ok, nabu",
    )

    reply = await client.receive_json()
    assert reply["success"]
    assert reply["result"] == {"wake_word_phrase": "ok, nabu"}
async def test_intercept_wake_word_requires_on_device_wake_word(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test that interception fails when the pipeline starts at the wake word stage."""
    client = await hass_ws_client(hass)
    request = {
        "type": "assist_satellite/intercept_wake_word",
        "entity_id": ENTITY_ID,
    }
    await client.send_json_auto_id(request)

    # Yield to the event loop a few times before starting the pipeline
    # (presumably so the websocket command is handled first)
    for _ in range(3):
        await asyncio.sleep(0)

    # Starting at WAKE_WORD emulates wake word processing in Home Assistant
    await entity.async_accept_pipeline_from_satellite(
        object(),
        start_stage=PipelineStage.WAKE_WORD,
    )

    reply = await client.receive_json()
    expected_error = {
        "code": "home_assistant_error",
        "message": "Only on-device wake words currently supported",
    }
    assert not reply["success"]
    assert reply["error"] == expected_error
async def test_intercept_wake_word_requires_wake_word_phrase(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test intercepting a wake word fails if no wake word phrase is provided."""
    ws_client = await hass_ws_client(hass)
    await ws_client.send_json_auto_id(
        {
            "type": "assist_satellite/intercept_wake_word",
            "entity_id": ENTITY_ID,
        }
    )

    # Yield to the event loop a few times before starting the pipeline
    # (presumably so the websocket command is handled first)
    for _ in range(3):
        await asyncio.sleep(0)

    await entity.async_accept_pipeline_from_satellite(
        object(),
        start_stage=PipelineStage.STT,
        # We are not passing wake word phrase
    )

    response = await ws_client.receive_json()
    assert not response["success"]
    assert response["error"] == {
        "code": "home_assistant_error",
        "message": "No wake word phrase provided",
    }
async def test_intercept_wake_word_require_admin(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
    hass_admin_user: MockUser,
) -> None:
    """Test that a user without any groups cannot intercept a wake word."""
    # Strip every group from the user before the websocket client is created
    hass_admin_user.groups = []
    client = await hass_ws_client(hass)

    request = {
        "type": "assist_satellite/intercept_wake_word",
        "entity_id": ENTITY_ID,
    }
    await client.send_json_auto_id(request)

    reply = await client.receive_json()
    expected_error = {
        "code": "unauthorized",
        "message": "Unauthorized",
    }
    assert not reply["success"]
    assert reply["error"] == expected_error
async def test_intercept_wake_word_invalid_satellite(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test intercepting a wake word fails for an unknown satellite entity."""
    ws_client = await hass_ws_client(hass)

    await ws_client.send_json_auto_id(
        {
            "type": "assist_satellite/intercept_wake_word",
            # No entity with this id is set up by the fixtures
            "entity_id": "assist_satellite.invalid",
        }
    )
    response = await ws_client.receive_json()
    assert not response["success"]
    assert response["error"] == {
        "code": "not_found",
        "message": "Entity not found",
    }
async def test_intercept_wake_word_twice(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test a second interception request fails while one is already in progress."""
    ws_client = await hass_ws_client(hass)

    # First request: starts an interception that is never completed here
    await ws_client.send_json_auto_id(
        {
            "type": "assist_satellite/intercept_wake_word",
            "entity_id": ENTITY_ID,
        }
    )

    # Second request for the same entity
    await ws_client.send_json_auto_id(
        {
            "type": "assist_satellite/intercept_wake_word",
            "entity_id": ENTITY_ID,
        }
    )

    response = await ws_client.receive_json()
    assert not response["success"]
    assert response["error"] == {
        "code": "home_assistant_error",
        "message": "Wake word interception already in progress",
    }

View file

@ -20,7 +20,6 @@ from aioesphomeapi import (
ReconnectLogic,
UserService,
VoiceAssistantAudioSettings,
VoiceAssistantEventType,
VoiceAssistantFeature,
)
import pytest
@ -34,11 +33,6 @@ from homeassistant.components.esphome.const import (
DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS,
DOMAIN,
)
from homeassistant.components.esphome.entry_data import RuntimeEntryData
from homeassistant.components.esphome.voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantUDPPipeline,
)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -625,57 +619,3 @@ async def mock_esphome_device(
)
return _mock_device
@pytest.fixture
def mock_voice_assistant_api_pipeline() -> VoiceAssistantAPIPipeline:
    """Patch VoiceAssistantAPIPipeline with a mock and yield that mock."""
    pipeline_mock = Mock(spec=VoiceAssistantAPIPipeline)

    def _fake_init(
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
        api_client: APIClient,
    ):
        """Store the constructor arguments on the mock and return it."""
        pipeline_mock.hass = hass
        pipeline_mock.entry_data = entry_data
        pipeline_mock.handle_event = handle_event
        pipeline_mock.handle_finished = handle_finished
        pipeline_mock.api_client = api_client
        return pipeline_mock

    # Calling the patched class runs the fake constructor
    pipeline_mock.side_effect = _fake_init

    target = "homeassistant.components.esphome.voice_assistant.VoiceAssistantAPIPipeline"
    with patch(target, new=pipeline_mock):
        yield pipeline_mock
@pytest.fixture
def mock_voice_assistant_udp_pipeline() -> VoiceAssistantUDPPipeline:
    """Return the UDP Pipeline factory.

    VoiceAssistantUDPPipeline is patched with a mock for the duration of the
    test; the mock is yielded.
    """
    mock_pipeline = Mock(spec=VoiceAssistantUDPPipeline)

    def mock_constructor(
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
    ):
        """Fake the constructor."""
        # Keep the arguments on the mock so tests can inspect or invoke them
        mock_pipeline.hass = hass
        mock_pipeline.entry_data = entry_data
        mock_pipeline.handle_event = handle_event
        mock_pipeline.handle_finished = handle_finished
        return mock_pipeline

    mock_pipeline.side_effect = mock_constructor

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPPipeline",
        new=mock_pipeline,
    ):
        yield mock_pipeline

View file

@ -0,0 +1,822 @@
"""Test ESPHome voice assistant server."""
import asyncio
from collections.abc import Awaitable, Callable
import io
import socket
from unittest.mock import ANY, Mock, patch
import wave
from aioesphomeapi import (
APIClient,
EntityInfo,
EntityState,
UserService,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
import pytest
from homeassistant.components import assist_satellite
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteState,
)
from homeassistant.components.esphome import DOMAIN
from homeassistant.components.esphome.assist_satellite import (
EsphomeAssistSatellite,
VoiceAssistantUDPServer,
)
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er, intent as intent_helper
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.entity_component import EntityComponent
from .conftest import MockESPHomeDevice
def get_satellite_entity(
    hass: HomeAssistant, mac_address: str
) -> EsphomeAssistSatellite | None:
    """Return the ESPHome satellite entity registered for a MAC address.

    None is returned when the entity registry has no matching entry or the
    assist_satellite component does not hold the entity.
    """
    unique_id = f"{mac_address}-assist_satellite"
    entity_id = er.async_get(hass).async_get_entity_id(
        Platform.ASSIST_SATELLITE, DOMAIN, unique_id
    )
    if entity_id is None:
        return None

    component: EntityComponent[AssistSatelliteEntity] = hass.data[
        assist_satellite.DOMAIN
    ]
    satellite = component.get_entity(entity_id)
    if satellite is None:
        return None

    assert isinstance(satellite, EsphomeAssistSatellite)
    return satellite
@pytest.fixture
def mock_wav() -> bytes:
    """Return a small in-memory WAV file (16000 Hz, 2-byte samples, 1 channel)."""
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(16000)
            writer.writeframes(b"test-wav")

        # Read the bytes after the writer is closed but before the buffer is
        wav_bytes = buffer.getvalue()

    return wav_bytes
async def test_no_satellite_without_voice_assistant(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that no satellite entity exists for a device set up with empty device info."""
    device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={},
    )
    await hass.async_block_till_done()

    # No satellite entity should be created
    satellite = get_satellite_entity(hass, device.device_info.mac_address)
    assert satellite is None
async def test_pipeline_api_audio(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Run a whole pipeline using API audio (sent over the TCP connection).

    The real pipeline is replaced with a fake that replays each pipeline event
    in order. After every event the test checks what was forwarded to the
    device through ``send_voice_assistant_event`` and, at selected points, the
    satellite's state.
    """
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    feature_flags = (
        VoiceAssistantFeature.VOICE_ASSISTANT
        | VoiceAssistantFeature.SPEAKER
        | VoiceAssistantFeature.API_AUDIO
    )
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={"voice_assistant_feature_flags": feature_flags},
    )
    await hass.async_block_till_done()

    device_entry = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    def last_sent():
        """Return the positional args of the newest event sent to the device."""
        return mock_client.send_voice_assistant_event.call_args_list[-1].args

    # TTS streaming is held back until this event is set, which makes the
    # order of the pipeline events easier to verify.
    stream_tts_audio_ready = asyncio.Event()
    real_stream_tts_audio = satellite._stream_tts_audio

    async def _stream_tts_audio(*args, **kwargs):
        await stream_tts_audio_ready.wait()
        await real_stream_tts_audio(*args, **kwargs)

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        assert device_id == device_entry.id

        # Verify test API audio
        chunks = []
        async for chunk in kwargs["stt_stream"]:
            chunks.append(chunk)
        assert chunks == [b"test-mic"]

        event_callback = kwargs["event_callback"]

        def emit(event_type, data) -> None:
            """Push one pipeline event with the given payload into the satellite."""
            event_callback(PipelineEvent(type=event_type, data=data))

        # Unknown event types must not be forwarded at all
        emit("unknown-event", {})
        mock_client.send_voice_assistant_event.assert_not_called()

        # Error event
        emit(
            PipelineEventType.ERROR,
            {"code": "test-error-code", "message": "test-error-message"},
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {"code": "test-error-code", "message": "test-error-message"},
        )

        # Wake word
        assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
        emit(
            PipelineEventType.WAKE_WORD_START,
            {
                "entity_id": "test-wake-word-entity-id",
                "metadata": {},
                "timeout": 0,
            },
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START,
            {},
        )

        # An empty wake word output is reported to the device as an error
        emit(PipelineEventType.WAKE_WORD_END, {"wake_word_output": {}})
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {"code": "no_wake_word", "message": "No wake word detected"},
        )

        # Correct wake word detection
        emit(
            PipelineEventType.WAKE_WORD_END,
            {"wake_word_output": {"wake_word_phrase": "test-wake-word"}},
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END,
            {},
        )

        # STT
        emit(
            PipelineEventType.STT_START,
            {"engine": "test-stt-engine", "metadata": {}},
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
            {},
        )
        assert satellite.state == AssistSatelliteState.LISTENING_COMMAND

        emit(PipelineEventType.STT_END, {"stt_output": {"text": "test-stt-text"}})
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
            {"text": "test-stt-text"},
        )

        # Intent
        emit(
            PipelineEventType.INTENT_START,
            {
                "engine": "test-intent-engine",
                "language": hass.config.language,
                "intent_input": "test-intent-text",
                "conversation_id": conversation_id,
                "device_id": device_id,
            },
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
            {},
        )
        assert satellite.state == AssistSatelliteState.PROCESSING

        emit(
            PipelineEventType.INTENT_END,
            {"intent_output": {"conversation_id": conversation_id}},
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
            {"conversation_id": conversation_id},
        )

        # TTS
        emit(
            PipelineEventType.TTS_START,
            {
                "engine": "test-stt-engine",
                "language": hass.config.language,
                "voice": "test-voice",
                "tts_input": "test-tts-text",
            },
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
            {"text": "test-tts-text"},
        )
        assert satellite.state == AssistSatelliteState.RESPONDING

        # Should return mock_wav audio
        emit(
            PipelineEventType.TTS_END,
            {"tts_output": {"url": media_url, "media_id": media_id}},
        )
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
            {"url": media_url},
        )

        # RUN_END is built without a data argument, as in the other tests
        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
        assert last_sent() == (
            VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END,
            {},
        )

        # Allow TTS streaming to proceed
        stream_tts_audio_ready.set()

    # Wrap the satellite's completion hooks so the test can wait on them.
    pipeline_finished = asyncio.Event()
    real_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        real_handle_pipeline_finished()
        pipeline_finished.set()

    tts_finished = asyncio.Event()
    real_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        real_tts_response_finished()
        tts_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("wav", mock_wav)

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        # Should be cleared at pipeline start
        satellite._audio_queue.put_nowait(b"leftover-data")

        # Should be cancelled at pipeline start
        stale_tts_task = Mock()
        satellite._tts_streaming_task = stale_tts_task

        async with asyncio.timeout(1):
            await satellite.handle_pipeline_start(
                conversation_id=conversation_id,
                flags=VoiceAssistantCommandFlag.USE_WAKE_WORD,
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )
            stale_tts_task.cancel.assert_called_once()
            await satellite.handle_audio(b"test-mic")
            await satellite.handle_pipeline_stop()
            await pipeline_finished.wait()
            await tts_finished.wait()

        # Verify TTS streaming events.
        # These are definitely the last two events because we blocked TTS
        # streaming until after RUN_END above.
        sent_events = mock_client.send_voice_assistant_event.call_args_list
        assert sent_events[-2].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
            {},
        )
        assert sent_events[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
            {},
        )

        # Verify TTS WAV audio chunk came through
        mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav")
@pytest.mark.usefixtures("socket_enabled")
async def test_pipeline_udp_audio(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test a complete pipeline run with legacy UDP audio.

    This test is not as comprehensive as test_pipeline_api_audio since we're
    mainly focused on the UDP server.

    The client-side datagram transport and the probe socket used at the end
    are always closed, so a failing assertion does not leak sockets.
    """
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    # No API_AUDIO flag here, unlike test_pipeline_api_audio.
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.SPEAKER
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    # Set once the fake pipeline has received an audio chunk.
    mic_audio_event = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        stt_stream = kwargs["stt_stream"]

        chunks = []
        async for chunk in stt_stream:
            chunks.append(chunk)
            # Signal per chunk so the test can stop the pipeline, which in
            # turn lets this loop finish.
            mic_audio_event.set()

        # Verify test UDP audio
        assert chunks == [b"test-mic"]

        event_callback = kwargs["event_callback"]

        # STT
        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_START,
                data={"engine": "test-stt-engine", "metadata": {}},
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_END,
                data={"stt_output": {"text": "test-stt-text"}},
            )
        )

        # Intent
        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_START,
                data={
                    "engine": "test-intent-engine",
                    "language": hass.config.language,
                    "intent_input": "test-intent-text",
                    "conversation_id": conversation_id,
                    "device_id": device_id,
                },
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_END,
                data={"intent_output": {"conversation_id": conversation_id}},
            )
        )

        # TTS
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_START,
                data={
                    "engine": "test-stt-engine",
                    "language": hass.config.language,
                    "voice": "test-voice",
                    "tts_input": "test-tts-text",
                },
            )
        )

        # Should return mock_wav audio
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={"tts_output": {"url": media_url, "media_id": media_id}},
            )
        )

        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))

    # Wrap the satellite's completion hooks so the test can wait on them.
    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("wav", mock_wav)

    tts_finished = asyncio.Event()
    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        original_tts_response_finished()
        tts_finished.set()

    class TestProtocol(asyncio.DatagramProtocol):
        """Client-side protocol that records every datagram it receives."""

        def __init__(self) -> None:
            self.transport = None
            self.data_received: list[bytes] = []

        def connection_made(self, transport):
            self.transport = transport

        def datagram_received(self, data: bytes, addr):
            self.data_received.append(data)

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        transport = None
        try:
            async with asyncio.timeout(1):
                port = await satellite.handle_pipeline_start(
                    conversation_id=conversation_id,
                    flags=VoiceAssistantCommandFlag(0),  # stt
                    audio_settings=VoiceAssistantAudioSettings(),
                    wake_word_phrase="",
                )
                assert (port is not None) and (port > 0)

                (
                    transport,
                    protocol,
                ) = await asyncio.get_running_loop().create_datagram_endpoint(
                    TestProtocol, remote_addr=("127.0.0.1", port)
                )
                assert isinstance(protocol, TestProtocol)

                # Send audio over UDP
                transport.sendto(b"test-mic")

                # Wait for audio chunk to be delivered
                await mic_audio_event.wait()

                await satellite.handle_pipeline_stop()
                await pipeline_finished.wait()

                await tts_finished.wait()

            # Verify TTS audio (from UDP)
            assert protocol.data_received == [b"test-wav"]
        finally:
            # The client endpoint was never closed before; release it even
            # when one of the assertions above fails.
            if transport is not None:
                transport.close()

        # Check that UDP server was stopped.
        # The context manager closes the probe socket even if bind() raises.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            sock.bind(("", port))  # will fail if UDP server is still running
async def test_udp_errors() -> None:
    """Exercise the error paths of the UDP protocol object."""
    queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    server = VoiceAssistantUDPServer(queue)

    # A received datagram ends up on the queue unchanged.
    server.datagram_received(b"test", ("", 0))
    assert queue.qsize() == 1
    first_item = await queue.get()
    assert first_item == b"test"

    # An error puts None on the queue; None will stop the pipeline.
    server.error_received(RuntimeError())
    assert queue.qsize() == 1
    second_item = await queue.get()
    assert second_item is None

    # Sending while there is no transport
    assert server.transport is None
    server.send_audio_bytes(b"test")

    # Sending with a transport but without a remote address
    server.transport = Mock()
    server.remote_addr = None
    server.send_audio_bytes(b"test")
    server.transport.sendto.assert_not_called()
async def test_timer_events(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that injecting timer events results in the correct api client calls.

    A timer is started and then extended through intents; after each intent
    the most recent ``send_voice_assistant_timer_event`` call is checked.
    """
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.TIMERS
        },
    )
    await hass.async_block_till_done()

    # Same guards as test_unknown_timer_event: fail with a clear assertion
    # instead of an AttributeError on `dev.id` if the lookup comes back empty.
    assert mock_device.entry.unique_id is not None
    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    assert dev is not None

    # 1 hour, 2 minutes and 3 seconds expressed in seconds.
    total_seconds = (1 * 60 * 60) + (2 * 60) + 3
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_START_TIMER,
        {
            "name": {"value": "test timer"},
            "hours": {"value": 1},
            "minutes": {"value": 2},
            "seconds": {"value": 3},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
        ANY,
        "test timer",
        total_seconds,
        total_seconds,
        True,
    )

    # Increase timer beyond original time and check total_seconds has increased
    mock_client.send_voice_assistant_timer_event.reset_mock()

    total_seconds += 5 * 60
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_INCREASE_TIMER,
        {
            "name": {"value": "test timer"},
            "minutes": {"value": 5},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
        ANY,
        "test timer",
        total_seconds,
        ANY,
        True,
    )
async def test_unknown_timer_event(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Check that a timer event type with no mapping produces no api call."""
    feature_flags = (
        VoiceAssistantFeature.VOICE_ASSISTANT | VoiceAssistantFeature.TIMERS
    )
    esphome_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={"voice_assistant_feature_flags": feature_flags},
    )
    await hass.async_block_till_done()

    unique_id = esphome_device.entry.unique_id
    assert unique_id is not None
    device_entry = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, esphome_device.entry.unique_id)}
    )
    assert device_entry is not None

    timer_slots = {
        "name": {"value": "test timer"},
        "hours": {"value": 1},
        "minutes": {"value": 2},
        "seconds": {"value": 3},
    }

    # Make the event type lookup raise KeyError for every event.
    with patch(
        "homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass",
        side_effect=KeyError,
    ):
        await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_START_TIMER,
            timer_slots,
            device_id=device_entry.id,
        )

        mock_client.send_voice_assistant_timer_event.assert_not_called()
async def test_streaming_tts_errors(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Cover the failure paths of the satellite's _stream_tts_audio method."""
    esphome_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, esphome_device.device_info.mac_address)
    assert satellite is not None

    media_source_target = "homeassistant.components.tts.async_get_media_source_audio"

    # Should not stream if not running
    satellite._is_running = False
    await satellite._stream_tts_audio("test-media-id")
    mock_client.send_voice_assistant_audio.assert_not_called()
    satellite._is_running = True

    async def get_mp3(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Should only stream WAV
        return ("mp3", b"")

    async def get_bad_wav(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Needs to be the correct sample rate, etc.: this WAV is 48000 Hz,
        # whereas the WAV built for the mock_wav fixture uses 16000 Hz.
        with io.BytesIO() as buffer:
            with wave.open(buffer, "wb") as writer:
                writer.setframerate(48000)
                writer.setsampwidth(2)
                writer.setnchannels(1)
                writer.writeframes(b"test-wav")

            return ("wav", buffer.getvalue())

    # Neither kind of unusable media may result in audio being sent.
    for fetch_media in (get_mp3, get_bad_wav):
        with patch(media_source_target, new=fetch_media):
            await satellite._stream_tts_audio("test-media-id")
            mock_client.send_voice_assistant_audio.assert_not_called()

    # Check that TTS_STREAM_* events still get sent after cancel
    media_fetched = asyncio.Event()

    async def get_slow_wav(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        media_fetched.set()
        await asyncio.sleep(1)
        return ("wav", mock_wav)

    mock_client.send_voice_assistant_event.reset_mock()
    with patch(media_source_target, new=get_slow_wav):
        streaming_task = asyncio.create_task(
            satellite._stream_tts_audio("test-media-id")
        )
        async with asyncio.timeout(1):
            # Wait for media to be fetched
            await media_fetched.wait()

        # Cancel task
        streaming_task.cancel()
        await streaming_task

        # No audio should have gone out
        mock_client.send_voice_assistant_audio.assert_not_called()
        sent_events = mock_client.send_voice_assistant_event.call_args_list
        assert len(sent_events) == 2

        # The TTS_STREAM_* events should have gone out
        assert sent_events[-2].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
            {},
        )
        assert sent_events[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
            {},
        )

View file

@ -2,7 +2,7 @@
import asyncio
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, call, patch
from unittest.mock import AsyncMock, call
from aioesphomeapi import (
APIClient,
@ -17,7 +17,6 @@ from aioesphomeapi import (
UserService,
UserServiceArg,
UserServiceArgType,
VoiceAssistantFeature,
)
import pytest
@ -29,10 +28,6 @@ from homeassistant.components.esphome.const import (
DOMAIN,
STABLE_BLE_VERSION_STR,
)
from homeassistant.components.esphome.voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantUDPPipeline,
)
from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
@ -44,7 +39,7 @@ from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import device_registry as dr, issue_registry as ir
from homeassistant.setup import async_setup_component
from .conftest import _ONE_SECOND, MockESPHomeDevice
from .conftest import MockESPHomeDevice
from tests.common import MockConfigEntry, async_capture_events, async_mock_service
@ -1214,102 +1209,3 @@ async def test_entry_missing_unique_id(
await mock_esphome_device(mock_client=mock_client, mock_storage=True)
await hass.async_block_till_done()
assert entry.unique_id == "11:22:33:44:55:aa"
async def test_manager_voice_assistant_handlers_api(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    caplog: pytest.LogCaptureFixture,
    mock_voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check the manager's start/audio/stop handlers with an API audio device."""
    feature_flags = (
        VoiceAssistantFeature.VOICE_ASSISTANT | VoiceAssistantFeature.API_AUDIO
    )
    device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={"voice_assistant_feature_flags": feature_flags},
    )
    await hass.async_block_till_done()

    pipeline = mock_voice_assistant_api_pipeline
    audio = bytes(_ONE_SECOND)

    with (
        patch(
            "homeassistant.components.esphome.manager.VoiceAssistantAPIPipeline",
            new=pipeline,
        ),
    ):
        # The first start is expected to report port 0.
        port: int | None = await device.mock_voice_assistant_handle_start(
            "", 0, None, None
        )
        assert port == 0

        # Starting again without stopping must log a warning.
        port = await device.mock_voice_assistant_handle_start("", 0, None, None)
        assert "Previous Voice assistant pipeline was not stopped" in caplog.text

        # Audio is handed to the pipeline while it is running.
        await device.mock_voice_assistant_handle_audio(audio)
        pipeline.receive_audio_bytes.assert_called_with(audio)

        pipeline.receive_audio_bytes.reset_mock()

        await device.mock_voice_assistant_handle_stop()
        pipeline.handle_finished()

        # After stop and finish, further audio must not reach the pipeline.
        await device.mock_voice_assistant_handle_audio(audio)
        pipeline.receive_audio_bytes.assert_not_called()
async def test_manager_voice_assistant_handlers_udp(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_voice_assistant_udp_pipeline: VoiceAssistantUDPPipeline,
) -> None:
    """Check the manager's start/stop handlers with a UDP audio device."""
    device_info = {
        "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
    }
    device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info=device_info,
    )
    await hass.async_block_till_done()

    pipeline = mock_voice_assistant_udp_pipeline

    with (
        patch(
            "homeassistant.components.esphome.manager.VoiceAssistantUDPPipeline",
            new=pipeline,
        ),
    ):
        # Starting must run the pipeline.
        await device.mock_voice_assistant_handle_start("", 0, None, None)
        pipeline.run_pipeline.assert_called()

        # Stopping and finishing must stop and close it.
        await device.mock_voice_assistant_handle_stop()
        pipeline.handle_finished()

        pipeline.stop.assert_called()
        pipeline.close.assert_called()

View file

@ -1,964 +0,0 @@
"""Test ESPHome voice assistant server."""
import asyncio
from collections.abc import Awaitable, Callable
import io
import socket
from unittest.mock import ANY, Mock, patch
import wave
from aioesphomeapi import (
APIClient,
EntityInfo,
EntityState,
UserService,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
import pytest
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.assist_pipeline.error import (
PipelineNotFound,
WakeWordDetectionAborted,
WakeWordDetectionError,
)
from homeassistant.components.esphome import DomainData
from homeassistant.components.esphome.voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantUDPPipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent as intent_helper
import homeassistant.helpers.device_registry as dr
from .conftest import _ONE_SECOND, MockESPHomeDevice
# Text used as the STT output in the fake pipeline events below.
_TEST_INPUT_TEXT = "This is an input test"
# Text used as the TTS input in the fake pipeline events below.
_TEST_OUTPUT_TEXT = "This is an output test"
# URL used in the "tts_output" payload of fake TTS_END events.
_TEST_OUTPUT_URL = "output.mp3"
# Media id used in the "tts_output" payload of fake TTS_END events.
_TEST_MEDIA_ID = "12345"
@pytest.fixture
def voice_assistant_udp_pipeline(
    hass: HomeAssistant,
) -> Callable[..., VoiceAssistantUDPPipeline]:
    """Return the UDP pipeline factory.

    The returned callable takes a config entry (``entry``) and builds a
    ``VoiceAssistantUDPPipeline`` from that entry's runtime data.
    """

    def _voice_assistant_udp_server(entry) -> VoiceAssistantUDPPipeline:
        entry_data = DomainData.get(hass).get_entry_data(entry)

        # Declared before the callback so the closure below can refer to it;
        # it is assigned once the pipeline has been constructed.
        server: VoiceAssistantUDPPipeline | None = None

        def handle_finished() -> None:
            # Passed to the pipeline as its fourth constructor argument;
            # closes the server this factory call created.
            assert server is not None
            server.close()

        server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished)
        return server  # noqa: RET504

    return _voice_assistant_udp_server
@pytest.fixture
def voice_assistant_api_pipeline(
    hass: HomeAssistant,
    mock_client,
    mock_voice_assistant_api_entry,
) -> VoiceAssistantAPIPipeline:
    """Return an API pipeline built for the mocked API-audio config entry."""
    domain_data = DomainData.get(hass)
    entry_data = domain_data.get_entry_data(mock_voice_assistant_api_entry)
    # The two Mock() arguments stand in for the pipeline's callbacks
    # (tests assert on `handle_event` and call `handle_finished`).
    pipeline = VoiceAssistantAPIPipeline(
        hass, entry_data, Mock(), Mock(), mock_client
    )
    return pipeline
@pytest.fixture
def voice_assistant_udp_pipeline_v1(
    voice_assistant_udp_pipeline,
    mock_voice_assistant_v1_entry,
) -> VoiceAssistantUDPPipeline:
    """Return a UDP pipeline built from the v1 voice assistant entry."""
    build_pipeline = voice_assistant_udp_pipeline
    pipeline = build_pipeline(entry=mock_voice_assistant_v1_entry)
    return pipeline
@pytest.fixture
def voice_assistant_udp_pipeline_v2(
    voice_assistant_udp_pipeline,
    mock_voice_assistant_v2_entry,
) -> VoiceAssistantUDPPipeline:
    """Return a UDP pipeline built from the v2 voice assistant entry."""
    build_pipeline = voice_assistant_udp_pipeline
    pipeline = build_pipeline(entry=mock_voice_assistant_v2_entry)
    return pipeline
@pytest.fixture
def mock_wav() -> bytes:
    """Return WAV bytes holding ``bytes(_ONE_SECOND)`` of zeroed audio data.

    The WAV is mono, 2 bytes per sample, at 16000 Hz.
    """
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(16000)
            writer.writeframes(bytes(_ONE_SECOND))

        # Read the bytes only after the writer is closed, while the buffer
        # is still open.
        wav_bytes = buffer.getvalue()
    return wav_bytes
async def test_pipeline_events(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check that run_pipeline calls the pipeline function and forwards its events."""
    pipeline = voice_assistant_udp_pipeline_v1

    # Pipeline events replayed by the fake pipeline, in order.
    fake_events = (
        (PipelineEventType.WAKE_WORD_END, {"wake_word_output": {}}),
        (PipelineEventType.STT_START, {}),
        (PipelineEventType.STT_END, {"stt_output": {"text": _TEST_INPUT_TEXT}}),
        (PipelineEventType.TTS_START, {"tts_input": _TEST_OUTPUT_TEXT}),
        (PipelineEventType.TTS_END, {"tts_output": {"url": _TEST_OUTPUT_URL}}),
    )

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        assert device_id == "mock-device-id"

        event_callback = kwargs["event_callback"]
        for event_type, payload in fake_events:
            event_callback(PipelineEvent(type=event_type, data=payload))

    # For each checked device event: the payload key and the value it must hold.
    expected_payload = {
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: ("text", _TEST_INPUT_TEXT),
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: ("text", _TEST_OUTPUT_TEXT),
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: ("url", _TEST_OUTPUT_URL),
    }

    def handle_event(
        event_type: VoiceAssistantEventType, data: dict[str, str] | None
    ) -> None:
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
            assert data is None
            return

        expectation = expected_payload.get(event_type)
        if expectation is None:
            # Other event types are not checked.
            return

        key, value = expectation
        assert data is not None
        assert data[key] == value

    pipeline.handle_event = handle_event

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        pipeline.transport = Mock()

        await pipeline.run_pipeline(device_id="mock-device-id", conversation_id=None)
@pytest.mark.usefixtures("socket_enabled")
async def test_udp_server(
    unused_udp_port_factory: Callable[[], int],
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Test the UDP server runs and queues incoming data.

    The client socket used to send the datagram is closed by a context
    manager, so it is released even when an assertion fails.
    """
    port_to_use = unused_udp_port_factory()

    with patch(
        "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
    ):
        port = await voice_assistant_udp_pipeline_v1.start_server()
        assert port == port_to_use

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
            sock.sendto(b"test", ("127.0.0.1", port))

            # Give the socket some time to send/receive the data
            async with asyncio.timeout(1):
                while voice_assistant_udp_pipeline_v1.queue.qsize() == 0:
                    await asyncio.sleep(0.1)

        assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1

        voice_assistant_udp_pipeline_v1.stop()
        voice_assistant_udp_pipeline_v1.close()

        assert voice_assistant_udp_pipeline_v1.transport.is_closing()
async def test_udp_server_queue(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check how the UDP server queues data before and after being stopped."""
    pipeline = voice_assistant_udp_pipeline_v1
    sender = ("localhost", 0)

    pipeline.started = True
    assert pipeline.queue.qsize() == 0

    # Each received datagram adds one queue entry.
    for expected_size in (1, 2):
        pipeline.datagram_received(bytes(1024), sender)
        assert pipeline.queue.qsize() == expected_size

    # Consume exactly one packet.
    async for packet in pipeline._iterate_packets():
        assert packet == bytes(1024)
        break
    assert pipeline.queue.qsize() == 1  # One message removed

    pipeline.stop()
    assert pipeline.queue.qsize() == 2  # An empty message added by stop

    pipeline.datagram_received(bytes(1024), sender)
    assert pipeline.queue.qsize() == 2  # No new messages added after stop

    pipeline.close()

    # Stopping the UDP server should cause _iterate_packets to break out
    # immediately without yielding any data.
    leftover = [packet async for packet in pipeline._iterate_packets()]
    assert not leftover, "Server was stopped"
async def test_api_pipeline_queue(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check how the API pipeline queues data before and after being stopped."""
    pipeline = voice_assistant_api_pipeline

    pipeline.started = True
    assert pipeline.queue.qsize() == 0

    # Each received audio chunk adds one queue entry.
    for expected_size in (1, 2):
        pipeline.receive_audio_bytes(bytes(1024))
        assert pipeline.queue.qsize() == expected_size

    # Consume exactly one packet.
    async for packet in pipeline._iterate_packets():
        assert packet == bytes(1024)
        break
    assert pipeline.queue.qsize() == 1  # One message removed

    pipeline.stop()
    assert pipeline.queue.qsize() == 2  # An empty message added by stop

    pipeline.receive_audio_bytes(bytes(1024))
    assert pipeline.queue.qsize() == 2  # No new messages added after stop

    # Stopping the API Pipeline should cause _iterate_packets to break out
    # immediately without yielding any data.
    leftover = [packet async for packet in pipeline._iterate_packets()]
    assert not leftover, "Pipeline was stopped"
async def test_error_calls_handle_finished(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check that error_received invokes the handle_finished callback."""
    pipeline = voice_assistant_udp_pipeline_v1

    # Swap the callback for a mock so the call can be observed.
    pipeline.handle_finished = Mock()

    failure = Exception()
    pipeline.error_received(failure)

    pipeline.handle_finished.assert_called()
@pytest.mark.usefixtures("socket_enabled")
async def test_udp_server_multiple(
    unused_udp_port_factory: Callable[[], int],
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check that starting the UDP server a second time raises RuntimeError."""
    udp_port_target = "homeassistant.components.esphome.voice_assistant.UDP_PORT"
    pipeline = voice_assistant_udp_pipeline_v1

    # First start succeeds.
    first_port = unused_udp_port_factory()
    with patch(udp_port_target, new=first_port):
        await pipeline.start_server()

    # Second start, on another unused port, must be rejected.
    second_port = unused_udp_port_factory()
    with patch(udp_port_target, new=second_port), pytest.raises(RuntimeError):
        await pipeline.start_server()
@pytest.mark.usefixtures("socket_enabled")
async def test_udp_server_after_stopped(
    unused_udp_port_factory: Callable[[], int],
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check that starting the UDP server after close() raises RuntimeError."""
    pipeline = voice_assistant_udp_pipeline_v1
    pipeline.close()

    free_port = unused_udp_port_factory()
    port_patch = patch(
        "homeassistant.components.esphome.voice_assistant.UDP_PORT",
        new=free_port,
    )
    with port_patch, pytest.raises(RuntimeError):
        await pipeline.start_server()
async def test_events_converted_correctly(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check the data sent to the device for each pipeline event type."""
    pipeline = voice_assistant_api_pipeline

    # (pipeline event type, pipeline event data,
    #  expected device event type, expected device event data)
    conversions = (
        (
            PipelineEventType.STT_START,
            {},
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
            None,
        ),
        (
            PipelineEventType.STT_END,
            {"stt_output": {"text": "text"}},
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
            {"text": "text"},
        ),
        (
            PipelineEventType.INTENT_START,
            {},
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
            None,
        ),
        (
            PipelineEventType.INTENT_END,
            {
                "intent_output": {
                    "conversation_id": "conversation-id",
                }
            },
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
            {"conversation_id": "conversation-id"},
        ),
        (
            PipelineEventType.TTS_START,
            {"tts_input": "text"},
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
            {"text": "text"},
        ),
        (
            PipelineEventType.TTS_END,
            {"tts_output": {"url": "url", "media_id": "media-id"}},
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
            {"url": "url"},
        ),
    )

    # _send_tts is patched out for the whole sequence.
    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts",
    ):
        for pipeline_type, pipeline_data, device_type, device_data in conversions:
            pipeline._event_callback(
                PipelineEvent(type=pipeline_type, data=pipeline_data)
            )
            pipeline.handle_event.assert_called_with(device_type, device_data)
async def test_unknown_event_type(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that an unrecognised event type never reaches handle_event."""
    unknown_event = PipelineEvent(type="unknown-event", data={})

    voice_assistant_api_pipeline._event_callback(unknown_event)

    assert not voice_assistant_api_pipeline.handle_event.called
async def test_error_event_type(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that a pipeline ERROR event is forwarded with its code and message."""
    error_event = PipelineEvent(
        type=PipelineEventType.ERROR,
        data={"code": "code", "message": "message"},
    )

    voice_assistant_api_pipeline._event_callback(error_event)

    voice_assistant_api_pipeline.handle_event.assert_called_with(
        VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
        {"code": "code", "message": "message"},
    )
async def test_send_tts_not_called(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None:
    """Check that a v1 UDP pipeline does not invoke _send_tts on TTS_END."""
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
    ) as send_tts_mock:
        voice_assistant_udp_pipeline_v1._event_callback(tts_end)

        send_tts_mock.assert_not_called()
async def test_send_tts_called_udp(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
) -> None:
    """Check that a v2 UDP pipeline invokes _send_tts with the media id."""
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
    ) as send_tts_mock:
        voice_assistant_udp_pipeline_v2._event_callback(tts_end)

        send_tts_mock.assert_called_with(_TEST_MEDIA_ID)
async def test_send_tts_called_api(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that the API pipeline invokes _send_tts with the media id."""
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
    ) as send_tts_mock:
        voice_assistant_api_pipeline._event_callback(tts_end)

        send_tts_mock.assert_called_with(_TEST_MEDIA_ID)
async def test_send_tts_not_called_when_empty(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that no pipeline flavour calls _send_tts for an empty tts_output."""
    pipelines = (
        voice_assistant_udp_pipeline_v1,
        voice_assistant_udp_pipeline_v2,
        voice_assistant_api_pipeline,
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
    ) as send_tts_mock:
        # The same mock is checked after every pipeline, in the original order.
        for pipeline in pipelines:
            pipeline._event_callback(
                PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
            )
            send_tts_mock.assert_not_called()
async def test_send_tts_udp(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
    mock_wav: bytes,
) -> None:
    """Check that the UDP pipeline pushes TTS audio through transport.sendto."""
    pipeline = voice_assistant_udp_pipeline_v2
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={
            "tts_output": {
                "media_id": _TEST_MEDIA_ID,
                "url": _TEST_OUTPUT_URL,
            }
        },
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", mock_wav),
    ):
        pipeline.started = True
        pipeline.transport = Mock(spec=asyncio.DatagramTransport)

        # Report the transport as open while the audio is being sent.
        with patch.object(pipeline.transport, "is_closing", return_value=False):
            pipeline._event_callback(tts_end)

            await pipeline._tts_done.wait()

            pipeline.transport.sendto.assert_called()
async def test_send_tts_api(
    hass: HomeAssistant,
    mock_client: APIClient,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
    mock_wav: bytes,
) -> None:
    """Check that the API pipeline sends TTS audio via the API client.

    After a TTS_END event the client's ``send_voice_assistant_audio`` must
    have been called once ``_tts_done`` is set.
    """
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={
            "tts_output": {
                "media_id": _TEST_MEDIA_ID,
                "url": _TEST_OUTPUT_URL,
            }
        },
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", mock_wav),
    ):
        voice_assistant_api_pipeline.started = True

        voice_assistant_api_pipeline._event_callback(tts_end)

        await voice_assistant_api_pipeline._tts_done.wait()

        mock_client.send_voice_assistant_audio.assert_called()
async def test_send_tts_wrong_sample_rate(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that audio not sampled at 16000Hz makes the TTS task raise ValueError."""
    # Build a 22050Hz, 16-bit mono WAV in memory.
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setframerate(22050)
            writer.setsampwidth(2)
            writer.setnchannels(1)
            writer.writeframes(bytes(_ONE_SECOND))

        # Read the bytes only after the wave writer has been closed.
        wav_bytes = buffer.getvalue()

    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", wav_bytes),
    ):
        voice_assistant_api_pipeline.started = True
        voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)

        voice_assistant_api_pipeline._event_callback(tts_end)

        assert voice_assistant_api_pipeline._tts_task is not None
        with pytest.raises(ValueError):
            await voice_assistant_api_pipeline._tts_task
async def test_send_tts_wrong_format(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that non-WAV audio makes the TTS task raise ValueError."""
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    # The media source reports a "raw" extension instead of "wav".
    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("raw", bytes(1024)),
    ):
        voice_assistant_api_pipeline.started = True
        voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)

        voice_assistant_api_pipeline._event_callback(tts_end)

        assert voice_assistant_api_pipeline._tts_task is not None
        with pytest.raises(ValueError):
            await voice_assistant_api_pipeline._tts_task
async def test_send_tts_not_started(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
    mock_wav: bytes,
) -> None:
    """Check that the UDP pipeline sends nothing while ``started`` is False."""
    pipeline = voice_assistant_udp_pipeline_v2
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", mock_wav),
    ):
        pipeline.started = False
        pipeline.transport = Mock(spec=asyncio.DatagramTransport)

        pipeline._event_callback(tts_end)

        await pipeline._tts_done.wait()

        pipeline.transport.sendto.assert_not_called()
async def test_send_tts_transport_none(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
    mock_wav: bytes,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that a missing transport is logged when TTS audio is due."""
    pipeline = voice_assistant_udp_pipeline_v2
    tts_end = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", mock_wav),
    ):
        pipeline.started = True
        pipeline.transport = None

        pipeline._event_callback(tts_end)

        await pipeline._tts_done.wait()

    assert "No transport to send audio to" in caplog.text
async def test_wake_word(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that running with flags=2 starts the pipeline at the wake word stage."""

    async def fake_pipeline_from_audio_stream(*args, start_stage, **kwargs):
        # Stand-in for the real pipeline runner: only inspects the start stage.
        assert start_stage == PipelineStage.WAKE_WORD

    stream_patch = patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=fake_pipeline_from_audio_stream,
    )
    wait_patch = patch("asyncio.Event.wait")  # TTS wait event

    with stream_patch, wait_patch:
        await voice_assistant_api_pipeline.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=2,
        )
async def test_wake_word_exception(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Test that a wake word detection error is reported as an error event.

    The patched pipeline runner raises ``WakeWordDetectionError``; the test
    expects ``handle_event`` to receive a ``VOICE_ASSISTANT_ERROR`` event
    carrying the error's code and message.
    """

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")

    # Payloads of every error event delivered to the handler.
    error_payloads: list[dict[str, str] | None] = []

    def handle_event(
        event_type: VoiceAssistantEventType, data: dict[str, str] | None
    ) -> None:
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            error_payloads.append(data)

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        voice_assistant_api_pipeline.handle_event = handle_event

        await voice_assistant_api_pipeline.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=2,
        )

    # Without this check the test would pass even if no error event was sent.
    assert error_payloads, "expected a VOICE_ASSISTANT_ERROR event"
    for data in error_payloads:
        assert data is not None
        assert data["code"] == "pipeline-not-found"
        assert data["message"] == "Pipeline not found"
async def test_wake_word_abort_exception(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Check that an aborted wake word detection emits no event at all."""

    async def aborting_pipeline_from_audio_stream(*args, **kwargs):
        # Stand-in for the real pipeline runner: always aborts.
        raise WakeWordDetectionAborted

    stream_patch = patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=aborting_pipeline_from_audio_stream,
    )
    handler_patch = patch.object(voice_assistant_api_pipeline, "handle_event")

    with stream_patch, handler_patch as handle_event_mock:
        await voice_assistant_api_pipeline.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=2,
        )

        handle_event_mock.assert_not_called()
async def test_timer_events(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that injecting timer events results in the correct api client calls.

    A timer is started through the start-timer intent and then extended
    through the increase-timer intent; after each step the arguments of the
    latest ``send_voice_assistant_timer_event`` call are checked.
    """
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.TIMERS
        },
    )
    await hass.async_block_till_done()

    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    # Fail with a clear message instead of an AttributeError on ``dev.id``.
    assert dev is not None

    total_seconds = (1 * 60 * 60) + (2 * 60) + 3
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_START_TIMER,
        {
            "name": {"value": "test timer"},
            "hours": {"value": 1},
            "minutes": {"value": 2},
            "seconds": {"value": 3},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
        ANY,
        "test timer",
        total_seconds,
        total_seconds,
        True,
    )

    # Increase timer beyond original time and check total_seconds has increased
    mock_client.send_voice_assistant_timer_event.reset_mock()

    total_seconds += 5 * 60
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_INCREASE_TIMER,
        {
            "name": {"value": "test timer"},
            "minutes": {"value": 5},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
        ANY,
        "test timer",
        total_seconds,
        ANY,
        True,
    )
async def test_unknown_timer_event(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that unknown (new) timer event types do not result in api calls.

    ``_TIMER_EVENT_TYPES.from_hass`` is patched to raise ``KeyError`` while a
    timer is started; no timer event may be sent to the client.
    """
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.TIMERS
        },
    )
    await hass.async_block_till_done()

    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    # Fail with a clear message instead of an AttributeError on ``dev.id``.
    assert dev is not None

    with patch(
        "homeassistant.components.esphome.voice_assistant._TIMER_EVENT_TYPES.from_hass",
        side_effect=KeyError,
    ):
        await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "hours": {"value": 1},
                "minutes": {"value": 2},
                "seconds": {"value": 3},
            },
            device_id=dev.id,
        )

        mock_client.send_voice_assistant_timer_event.assert_not_called()
async def test_invalid_pipeline_id(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Test that an unknown pipeline id is reported as an error event.

    The patched pipeline runner raises ``PipelineNotFound``; the test expects
    ``handle_event`` to receive a ``VOICE_ASSISTANT_ERROR`` event carrying the
    error's code and message.
    """
    invalid_pipeline_id = "invalid-pipeline-id"

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        raise PipelineNotFound(
            "pipeline_not_found", f"Pipeline {invalid_pipeline_id} not found"
        )

    # Payloads of every error event delivered to the handler.
    error_payloads: list[dict[str, str] | None] = []

    def handle_event(
        event_type: VoiceAssistantEventType, data: dict[str, str] | None
    ) -> None:
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            error_payloads.append(data)

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        voice_assistant_api_pipeline.handle_event = handle_event

        await voice_assistant_api_pipeline.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=2,
        )

    # Without this check the test would pass even if no error event was sent.
    assert error_payloads, "expected a VOICE_ASSISTANT_ERROR event"
    for data in error_payloads:
        assert data is not None
        assert data["code"] == "pipeline_not_found"
        assert data["message"] == f"Pipeline {invalid_pipeline_id} not found"

View file

@ -14,6 +14,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.tts.conftest import (
mock_tts_cache_dir_fixture_autouse, # noqa: F401
)
@pytest.fixture(autouse=True)

File diff suppressed because one or more lines are too long

View file

@ -3,15 +3,26 @@
import asyncio
import io
from pathlib import Path
import time
from unittest.mock import AsyncMock, Mock, patch
import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from voip_utils import CallInfo
from homeassistant.components import assist_pipeline, voip
from homeassistant.components.voip.devices import VoIPDevice
from homeassistant.components import assist_pipeline, assist_satellite, voip
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteState,
)
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
from homeassistant.const import STATE_OFF, STATE_ON, Platform
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import async_setup_component
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
return wav_io.getvalue()
def async_get_satellite_entity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> AssistSatelliteEntity | None:
    """Return the Assist satellite entity for a unique-id prefix, if registered.

    The entity id is looked up in the entity registry under the
    ``assist_satellite`` platform using ``<unique_id_prefix>-assist_satellite``
    as unique id. ``None`` is returned when no such registry entry exists.
    """
    registry = er.async_get(hass)
    unique_id = f"{unique_id_prefix}-assist_satellite"
    entity_id = registry.async_get_entity_id(
        Platform.ASSIST_SATELLITE, domain, unique_id
    )
    if entity_id is not None:
        component: EntityComponent[AssistSatelliteEntity] = hass.data[
            assist_satellite.DOMAIN
        ]
        return component.get_entity(entity_id)

    return None
async def test_is_valid_call(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
) -> None:
    """Check that a call is not allowed until the allow-call switch is on."""
    assert await async_setup_component(hass, "voip", {})

    protocol = HassVoipDatagramProtocol(hass, voip_devices)
    assert not protocol.is_valid_call(call_info)

    # The per-device "allow call" switch starts in the off state.
    switch_unique_id = f"{voip_device.voip_id}-allow_call"
    switch_entity_id = er.async_get(hass).async_get_entity_id(
        "switch", voip.DOMAIN, switch_unique_id
    )
    assert switch_entity_id is not None
    switch_state = hass.states.get(switch_entity_id)
    assert switch_state is not None
    assert switch_state.state == STATE_OFF

    # Allow calls
    hass.states.async_set(switch_entity_id, STATE_ON)
    assert protocol.is_valid_call(call_info)
async def test_calls_not_allowed(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that a pre-recorded message is played while calls are not allowed."""
    assert await async_setup_component(hass, "voip", {})

    protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"

    # Test the playback
    audio_sent = asyncio.Event()
    played_audio_bytes = b""

    def capture_audio(audio_bytes: bytes, **kwargs):
        nonlocal played_audio_bytes

        # Should be problem.pcm from components/voip
        played_audio_bytes = audio_bytes
        audio_sent.set()

    protocol.transport = Mock()
    protocol.loop_delay = 0
    with patch.object(protocol, "send_audio", capture_audio):
        protocol.on_chunk(bytes(_ONE_SECOND))

        # Give up after one second if no audio was sent.
        async with asyncio.timeout(1):
            await audio_sent.wait()

    assert sum(played_audio_bytes) > 0
    assert played_audio_bytes == snapshot()
async def test_pipeline_not_found(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that the pre-recorded message protocol is used when no pipeline exists."""
    assert await async_setup_component(hass, "voip", {})

    # Pretend the pipeline lookup finds nothing.
    pipeline_patch = patch(
        "homeassistant.components.voip.voip.async_get_pipeline", return_value=None
    )
    with pipeline_patch:
        protocol: PreRecordMessageProtocol = make_protocol(
            hass, voip_devices, call_info
        )

    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"
async def test_satellite_prepared(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that make_protocol hands back the satellite entity for a call."""
    assert await async_setup_component(hass, "voip", {})

    pipeline_settings = {
        "conversation_engine": "test",
        "conversation_language": "en",
        "language": "en",
        "name": "test",
        "stt_engine": "test",
        "stt_language": "en",
        "tts_engine": "test",
        "tts_language": "en",
        "tts_voice": None,
        "wake_word_entity": None,
        "wake_word_id": None,
    }
    pipeline = assist_pipeline.Pipeline(**pipeline_settings)

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    # Make the pipeline lookup return the pipeline built above.
    with patch(
        "homeassistant.components.voip.voip.async_get_pipeline",
        return_value=pipeline,
    ):
        protocol = make_protocol(hass, voip_devices, call_info)
        assert protocol == satellite
async def test_pipeline(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
) -> None:
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
voip_user_id = satellite.config_entry.data["user"]
assert voip_user_id
return 0
# Satellite is muted until a call begins
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
done = asyncio.Event()
# Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
async def async_pipeline_from_audio_stream(
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
):
assert context.user_id == voip_user_id
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
in_command = False
async for chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
assert _chunk != bad_chunk
assert chunk != bad_chunk
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Test empty data
event_callback(
@ -71,6 +229,38 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_START,
data={"engine": "test", "metadata": {}},
)
)
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_START,
data={
"engine": "test",
"language": hass.config.language,
"intent_input": "fake-text",
"conversation_id": None,
"device_id": None,
},
)
)
assert satellite.state == AssistSatelliteState.PROCESSING
# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
@ -83,6 +273,21 @@ async def test_pipeline(
)
)
# Fake tts result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_START,
data={
"engine": "test",
"language": hass.config.language,
"voice": "test",
"tts_input": "fake-text",
},
)
)
assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
@ -91,6 +296,18 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.RUN_END
)
)
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
@ -100,102 +317,56 @@ async def test_pipeline(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
)
rtp_protocol.transport = Mock()
satellite._tones = Tones(0)
satellite.transport = Mock()
satellite.connection_made(satellite.transport)
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
# Ensure audio queue is cleared before pipeline starts
rtp_protocol._audio_queue.put_nowait(bad_chunk)
satellite._audio_queue.put_nowait(bad_chunk)
def send_audio(*args, **kwargs):
# Test finished successfully
done.set()
# Don't send audio
pass
rtp_protocol.send_audio = Mock(side_effect=send_audio)
satellite.send_audio = Mock(side_effect=send_audio)
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
"""Test timeout during pipeline run."""
assert await async_setup_component(hass, "voip", {})
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
await asyncio.sleep(10)
with (
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
return_value=True,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
pipeline_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
await done.wait()
# Finished speaking
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
async def test_stt_stream_timeout(
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
) -> None:
"""Test timeout in STT stream during pipeline run."""
assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
pass
with patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
audio_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
satellite._tones = Tones(0)
satellite._audio_chunk_timeout = 0.001
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
satellite.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
async def test_tts_timeout(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -278,15 +448,7 @@ async def test_tts_timeout(
tone_bytes = bytes([1, 2, 3, 4])
def send_audio(audio_bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
# Block here to force a timeout in _send_tts
time.sleep(2)
async def async_send_audio(audio_bytes, **kwargs):
async def async_send_audio(audio_bytes: bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
@ -303,37 +465,22 @@ async def test_tts_timeout(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
tts_extra_timeout=0.001,
listening_tone_enabled=True,
processing_tone_enabled=True,
error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
)
rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes
rtp_protocol._error_bytes = tone_bytes
rtp_protocol.transport = Mock()
rtp_protocol.send_audio = Mock()
satellite._tts_extra_timeout = 0.001
for tone in Tones:
satellite._tone_bytes[tone] = tone_bytes
original_send_tts = rtp_protocol._send_tts
satellite.transport = Mock()
satellite.send_audio = Mock()
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -342,17 +489,17 @@ async def test_tts_timeout(
done.set()
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -361,26 +508,34 @@ async def test_tts_timeout(
async def test_tts_wrong_extension(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
async def test_empty_tts_output(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will not stream when output is empty."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -605,37 +754,78 @@ async def test_empty_tts_output(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to finish
async with asyncio.timeout(1):
await rtp_protocol._tts_done.wait()
await satellite._tts_done.wait()
mock_send_tts.assert_not_called()
async def test_pipeline_error(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that an ERROR pipeline event results in the error tone being played."""
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    tone_played = asyncio.Event()
    played_audio_bytes = b""

    async def fake_pipeline(*args, **kwargs):
        """Replace the real pipeline run: report an error right away."""
        emit = kwargs["event_callback"]
        emit(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.ERROR,
                data={"code": "error-code", "message": "error message"},
            )
        )

    async def async_send_audio(audio_bytes: bytes, **kwargs):
        """Record the audio handed to the satellite and unblock the test."""
        nonlocal played_audio_bytes

        # Expected to be error.pcm from components/voip
        played_audio_bytes = audio_bytes
        tone_played.set()

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        new=fake_pipeline,
    ):
        # Only the ERROR tone flag is enabled on the satellite for this test
        satellite._tones = Tones.ERROR
        satellite.transport = Mock()
        satellite._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]

        # Feed one chunk of silence into the satellite
        satellite.on_chunk(bytes(_ONE_SECOND))

        # Block (at most one second) until the captured audio has been recorded
        async with asyncio.timeout(1):
            await tone_played.wait()

        # The captured audio must be non-silent and match the stored snapshot
        assert sum(played_audio_bytes) > 0
        assert played_audio_bytes == snapshot()

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -3,8 +3,8 @@
from unittest.mock import Mock, patch
from homeassistant.components import assist_pipeline
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.pipeline import PipelineData
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
timer_handle.cancel()
timer_handle2.cancel()
timer_handle3.cancel()
async def test_queue_to_iterable() -> None:
    """Exercise queue_to_iterable: normal draining, then two timeout scenarios."""
    item_queue: asyncio.Queue[int | None] = asyncio.Queue()
    expected_items = list(range(10))

    for value in expected_items:
        await item_queue.put(value)

    # A trailing None terminates the stream
    await item_queue.put(None)

    actual_items = []
    async for received in hasync.queue_to_iterable(item_queue):
        actual_items.append(received)

    assert expected_items == actual_items

    # Timeout checks start from an empty queue
    assert item_queue.empty()

    # Case 1: nothing is ever queued, so waiting for the first item times out
    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # The 0.01s timeout should fire well before the outer 1s guard
            async for _item in hasync.queue_to_iterable(item_queue, timeout=0.01):
                await asyncio.sleep(1)

    # Case 2: one item is available, the wait for the second one times out
    assert item_queue.empty()
    only_item = 12345
    await item_queue.put(only_item)

    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # The 0.01s timeout should fire well before the outer 1s guard
            async for item in hasync.queue_to_iterable(item_queue, timeout=0.01):
                if item != only_item:
                    await asyncio.sleep(1)

    assert item_queue.empty()