Add Assist satellite entity + VoIP (#123830)

* Add assist_satellite and implement VoIP

* Fix tests

* More tests

* Improve test

* Update entity state

* Set state correctly

* Move more functionality into base class

* Move RTP protocol into entity

* Fix tests

* Remove string

* Move to util method

* Align states better with pipeline events

* Remove public async_get_satellite_entity

* WAITING_FOR_WAKE_WORD

* Pass entity ids for pipeline/vad sensitivity

* Remove connect/disconnect

* Clean up

* Final cleanup
This commit is contained in:
Michael Hansen 2024-08-25 09:19:36 -05:00 committed by GitHub
parent 36bfd7b9ce
commit 644427ecc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1089 additions and 647 deletions

View file

@ -143,6 +143,7 @@ build.json @home-assistant/supervisor
/tests/components/aseko_pool_live/ @milanmeu
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View file

@ -16,6 +16,7 @@ from .const import (
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
OPTION_PREFERRED,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
@ -57,6 +58,7 @@ __all__ = (
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
"OPTION_PREFERRED",
"SAMPLES_PER_CHUNK",
"SAMPLE_RATE",
"SAMPLE_WIDTH",

View file

@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
MS_PER_CHUNK = 10
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"

View file

@ -504,7 +504,7 @@ class AudioSettings:
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
silence_seconds: float = 0.5
silence_seconds: float = 0.7
"""Seconds of silence after voice command has ended."""
def __post_init__(self) -> None:
@ -906,6 +906,8 @@ class PipelineRun:
metadata,
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
)
except (asyncio.CancelledError, TimeoutError):
raise # expected
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(

View file

@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(

View file

@ -0,0 +1,44 @@
"""Base class for assist satellite entities."""
import logging
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity
from .models import AssistSatelliteState
__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
]
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the Assist satellite entity component.

    Creates the ``EntityComponent`` that manages Assist satellite entities,
    stores it under ``hass.data[DOMAIN]`` and lets it process ``config``.
    Always returns True.
    """
    satellites = EntityComponent[AssistSatelliteEntity](_LOGGER, DOMAIN, hass)
    hass.data[DOMAIN] = satellites
    await satellites.async_setup(config)
    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a config entry.

    Delegates to the entity component stored in ``hass.data[DOMAIN]`` and
    returns its result.
    """
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    was_set_up = await satellites.async_setup_entry(entry)
    return was_set_up
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry.

    Delegates to the entity component stored in ``hass.data[DOMAIN]`` and
    returns its result.
    """
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    was_unloaded = await satellites.async_unload_entry(entry)
    return was_unloaded

View file

@ -0,0 +1,3 @@
"""Constants for assist satellite."""
# Domain name of the Assist satellite entity integration.
DOMAIN = "assist_satellite"

View file

@ -0,0 +1,151 @@
"""Assist satellite entity."""
from collections.abc import AsyncIterable
import time
from typing import Final
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.const import EntityCategory
from homeassistant.core import Context
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .models import AssistSatelliteState
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
class AssistSatelliteEntity(entity.Entity):
    """Entity encapsulating the state and functionality of an Assist satellite.

    Platforms subclass this entity and call
    ``_async_accept_pipeline_from_satellite`` to run an Assist pipeline on
    audio coming from the device. The entity state is updated from the
    pipeline events in ``on_pipeline_event``; after a text-to-speech response
    the platform must call ``tts_response_finished`` to return the entity to
    its wake word listening state.
    """

    entity_description = EntityDescription(
        key="assist_satellite",
        translation_key="assist_satellite",
        entity_category=EntityCategory.CONFIG,
    )
    _attr_has_entity_name = True
    _attr_name = None
    _attr_should_poll = False
    _attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD

    # Conversation id reused across pipeline runs until it has been idle for
    # longer than _CONVERSATION_TIMEOUT_SEC.
    _conversation_id: str | None = None
    # time.monotonic() timestamp of the most recent pipeline run.
    _conversation_id_time: float | None = None
    # True once the current pipeline run has emitted TTS_START.
    _run_has_tts: bool = False

    async def _async_accept_pipeline_from_satellite(
        self,
        audio_stream: AsyncIterable[bytes],
        start_stage: PipelineStage = PipelineStage.STT,
        end_stage: PipelineStage = PipelineStage.TTS,
        pipeline_entity_id: str | None = None,
        vad_sensitivity_entity_id: str | None = None,
        wake_word_phrase: str | None = None,
    ) -> None:
        """Trigger an Assist pipeline in Home Assistant from a satellite.

        Args:
            audio_stream: Audio chunks from the satellite. The STT metadata
                passed to the pipeline declares 16 kHz, 16-bit, mono PCM.
            start_stage: First pipeline stage to run.
            end_stage: Last pipeline stage to run.
            pipeline_entity_id: Entity whose state holds the name of the
                pipeline to use. When missing, or when its state is
                ``OPTION_PREFERRED``, no pipeline id is passed on.
            vad_sensitivity_entity_id: Entity whose state holds the VAD
                sensitivity; it is converted to the seconds of silence that
                end a voice command.
            wake_word_phrase: Wake word phrase forwarded to the pipeline.

        """
        pipeline_id: str | None = None
        vad_sensitivity = vad.VadSensitivity.DEFAULT

        if pipeline_entity_id:
            # Resolve pipeline by name
            pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
            if (pipeline_entity_state is not None) and (
                pipeline_entity_state.state != OPTION_PREFERRED
            ):
                for pipeline in async_get_pipelines(self.hass):
                    if pipeline.name == pipeline_entity_state.state:
                        pipeline_id = pipeline.id
                        break

        if vad_sensitivity_entity_id:
            vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
            if vad_sensitivity_state is not None:
                # NOTE(review): presumably raises ValueError when the state is
                # not a valid sensitivity value (e.g. "unknown") - confirm
                # whether a fallback to DEFAULT is wanted here.
                vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

        device_id: str | None = None
        if self.registry_entry is not None:
            device_id = self.registry_entry.device_id

        # Refresh context if necessary
        if (
            (self._context is None)
            or (self._context_set is None)
            or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
        ):
            self.async_set_context(Context())

        assert self._context is not None

        # Reset conversation id if necessary
        if (self._conversation_id_time is None) or (
            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
        ):
            self._conversation_id = None

        if self._conversation_id is None:
            self._conversation_id = ulid.ulid()

        # Update timeout
        self._conversation_id_time = time.monotonic()

        # Set entity state based on pipeline events
        self._run_has_tts = False

        await async_pipeline_from_audio_stream(
            self.hass,
            context=self._context,
            event_callback=self.on_pipeline_event,
            stt_metadata=stt.SpeechMetadata(
                language="",  # set in async_pipeline_from_audio_stream
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            ),
            stt_stream=audio_stream,
            pipeline_id=pipeline_id,
            conversation_id=self._conversation_id,
            device_id=device_id,
            tts_audio_output="wav",
            wake_word_phrase=wake_word_phrase,
            audio_settings=AudioSettings(
                silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
            ),
            # Forward the requested stages; previously these parameters were
            # accepted but silently ignored.
            start_stage=start_stage,
            end_stage=end_stage,
        )

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage."""
        if event.type == PipelineEventType.WAKE_WORD_START:
            self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
        elif event.type == PipelineEventType.STT_START:
            self._set_state(AssistSatelliteState.LISTENING_COMMAND)
        elif event.type == PipelineEventType.INTENT_START:
            self._set_state(AssistSatelliteState.PROCESSING)
        elif event.type == PipelineEventType.TTS_START:
            # Wait until tts_response_finished is called to return to waiting state
            self._run_has_tts = True
            self._set_state(AssistSatelliteState.RESPONDING)
        elif event.type == PipelineEventType.RUN_END:
            # Without a TTS response nobody will call tts_response_finished,
            # so return to the wake word state right away.
            if not self._run_has_tts:
                self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

    def _set_state(self, state: AssistSatelliteState) -> None:
        """Set the entity's state and write it to the state machine."""
        self._attr_state = state
        self.async_write_ha_state()

    def tts_response_finished(self) -> None:
        """Tell entity that the text-to-speech response has finished playing."""
        self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

View file

@ -0,0 +1,9 @@
{
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}

View file

@ -0,0 +1,19 @@
"""Models for assist satellite."""
from enum import StrEnum
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
LISTENING_WAKE_WORD = "listening_wake_word"
"""Device is streaming audio for wake word detection to Home Assistant."""
LISTENING_COMMAND = "listening_command"
"""Device is streaming audio with the voice command to Home Assistant."""
PROCESSING = "processing"
"""Home Assistant is processing the voice command."""
RESPONDING = "responding"
"""Device is speaking the response."""

View file

@ -0,0 +1,14 @@
{
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
}
}
}

View file

@ -20,6 +20,7 @@ from .devices import VoIPDevices
from .voip import HassVoipDatagramProtocol
PLATFORMS = (
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,

View file

@ -0,0 +1,298 @@
"""Assist satellite entity for VoIP integration."""
from __future__ import annotations
import asyncio
from enum import IntFlag
from functools import partial
import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Final
import wave
from voip_utils import RtpDatagramProtocol
from homeassistant.components import tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.async_ import queue_to_iterable
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity
if TYPE_CHECKING:
from . import DomainData
_LOGGER = logging.getLogger(__name__)
_PIPELINE_TIMEOUT_SEC: Final = 30
class Tones(IntFlag):
    """Feedback tones for specific events.

    Members are single-bit flags so any subset of tones can be combined
    with ``|`` and tested with ``&``.
    """

    LISTENING = 1 << 0
    PROCESSING = 1 << 1
    ERROR = 1 << 2
# File name of the raw audio played for each feedback tone. The names are
# resolved relative to this module's directory when the tone is first loaded.
_TONE_FILENAMES: dict[Tones, str] = {
Tones.LISTENING: "tone.pcm",
Tones.PROCESSING: "processing.pcm",
Tones.ERROR: "error.pcm",
}
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up VoIP Assist satellite entity.

    Registers a listener that adds a satellite entity for every VoIP device
    discovered later, then adds one satellite entity for each device that
    is already known.
    """
    domain_data: DomainData = hass.data[DOMAIN]

    def _make_satellite(device: VoIPDevice) -> VoipAssistSatellite:
        """Build the satellite entity for a single VoIP device."""
        return VoipAssistSatellite(hass, device, config_entry)

    @callback
    def async_add_device(device: VoIPDevice) -> None:
        """Add device."""
        async_add_entities([_make_satellite(device)])

    domain_data.devices.async_add_new_device_listener(async_add_device)

    async_add_entities([_make_satellite(device) for device in domain_data.devices])
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
    """Assist satellite for VoIP devices.

    The entity doubles as the RTP protocol of its VoIP device: incoming audio
    chunks are queued and fed to an Assist pipeline, and feedback tones and
    text-to-speech responses are sent back to the caller over RTP.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        voip_device: VoIPDevice,
        config_entry: ConfigEntry,
        tones: Tones = Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
    ) -> None:
        """Initialize an Assist satellite.

        Args:
            hass: Home Assistant instance (not stored here).
            voip_device: VoIP device this satellite belongs to.
            config_entry: Config entry that owns the background tasks and
                provides the user id for the pipeline context.
            tones: Feedback tones that are enabled.

        """
        VoIPEntity.__init__(self, voip_device)
        AssistSatelliteEntity.__init__(self)
        RtpDatagramProtocol.__init__(self)

        self.config_entry = config_entry

        # Incoming audio chunks waiting to be consumed by the pipeline.
        self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        # Timeout passed to queue_to_iterable when reading the audio queue.
        self._audio_chunk_timeout: float = 2.0
        self._pipeline_task: asyncio.Task | None = None
        self._pipeline_had_error: bool = False
        # Set when the TTS response has been sent (or there was none).
        self._tts_done = asyncio.Event()
        # Extra seconds allowed on top of the TTS audio duration.
        self._tts_extra_timeout: float = 1.0
        # Cache of tone audio, loaded on first use.
        self._tone_bytes: dict[Tones, bytes] = {}
        self._tones = tones
        # Set when the processing tone has finished playing.
        self._processing_tone_done = asyncio.Event()

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass."""
        await super().async_added_to_hass()
        self.voip_device.protocol = self

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        await super().async_will_remove_from_hass()
        # Only release the device if this entity is still its protocol.
        if self.voip_device.protocol is self:
            self.voip_device.protocol = None

    # -------------------------------------------------------------------------
    # VoIP
    # -------------------------------------------------------------------------

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Handle raw audio chunk."""
        if self._pipeline_task is None:
            self._clear_audio_queue()

            # Run pipeline until voice command finishes, then start over
            self._pipeline_task = self.config_entry.async_create_background_task(
                self.hass,
                self._run_pipeline(),
                "voip_pipeline_run",
            )

        self._audio_queue.put_nowait(audio_bytes)

    async def _run_pipeline(
        self,
    ) -> None:
        """Forward audio to pipeline STT and handle TTS."""
        self.async_set_context(Context(user_id=self.config_entry.data["user"]))
        self.voip_device.set_is_active(True)

        # Play listening tone at the start of each cycle
        await self._play_tone(Tones.LISTENING, silence_before=0.2)

        try:
            self._tts_done.clear()

            # Run pipeline with a timeout
            _LOGGER.debug("Starting pipeline")
            async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
                await self._async_accept_pipeline_from_satellite(  # noqa: SLF001
                    audio_stream=queue_to_iterable(
                        self._audio_queue, timeout=self._audio_chunk_timeout
                    ),
                    pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
                        self.hass
                    ),
                    vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
                        self.hass
                    ),
                )

            if self._pipeline_had_error:
                self._pipeline_had_error = False
                await self._play_tone(Tones.ERROR)
            else:
                # Block until TTS is done speaking.
                #
                # This is set in _send_tts and has a timeout that's based on the
                # length of the TTS audio.
                await self._tts_done.wait()

            _LOGGER.debug("Pipeline finished")
        except PipelineNotFound:
            _LOGGER.warning("Pipeline not found")
        except (asyncio.CancelledError, TimeoutError):
            # Expected after caller hangs up
            _LOGGER.debug("Pipeline cancelled or timed out")
            self.disconnect()
            self._clear_audio_queue()
        finally:
            self.voip_device.set_is_active(False)

            # Allow pipeline to run again
            self._pipeline_task = None

    def _clear_audio_queue(self) -> None:
        """Ensure audio queue is empty."""
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage."""
        super().on_pipeline_event(event)

        if event.type == PipelineEventType.STT_END:
            if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
                # _send_tts waits on this event so that the TTS response
                # does not overlap the processing tone.
                self._processing_tone_done.clear()
                self.config_entry.async_create_background_task(
                    self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
                )
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            if event.data and (tts_output := event.data["tts_output"]):
                media_id = tts_output["media_id"]
                self.config_entry.async_create_background_task(
                    self.hass,
                    self._send_tts(media_id),
                    "voip_pipeline_tts",
                )
            else:
                # Empty TTS response
                # NOTE(review): tts_response_finished() is not called on this
                # path - confirm the entity state is reset elsewhere.
                self._tts_done.set()
        elif event.type == PipelineEventType.ERROR:
            # Play error tone instead of wait for TTS when pipeline is finished.
            self._pipeline_had_error = True

    async def _send_tts(self, media_id: str) -> None:
        """Send TTS audio to caller via RTP.

        Raises:
            ValueError: The audio is not WAV, or its rate/width/channels do
                not match RATE/WIDTH/CHANNELS.
            TimeoutError: Sending took longer than the audio duration plus
                the extra timeout.

        """
        try:
            if self.transport is None:
                return  # not connected

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension != "wav":
                raise ValueError(f"Only WAV audio can be streamed, got {extension}")

            if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
                # Don't overlap TTS and processing beep
                await self._processing_tone_done.wait()

            with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
                sample_rate = wav_file.getframerate()
                sample_width = wav_file.getsampwidth()
                sample_channels = wav_file.getnchannels()

                if (
                    (sample_rate != RATE)
                    or (sample_width != WIDTH)
                    or (sample_channels != CHANNELS)
                ):
                    raise ValueError(
                        f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                        f" got {sample_rate}/{sample_width}/{sample_channels}"
                    )

                audio_bytes = wav_file.readframes(wav_file.getnframes())

            _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

            # Time out 1 second after TTS audio should be finished
            tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
            tts_seconds = tts_samples / RATE

            async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
                # TTS audio is 16Khz 16-bit mono
                await self._async_send_audio(audio_bytes)
        except TimeoutError:
            _LOGGER.warning("TTS timeout")
            raise
        finally:
            # Signal pipeline to restart
            self._tts_done.set()

            # Update satellite state
            self.tts_response_finished()

    async def _async_send_audio(self, audio_bytes: bytes, **kwargs) -> None:
        """Send audio in executor.

        Extra keyword arguments are passed to ``send_audio`` together with
        RTP_AUDIO_SETTINGS.
        """
        await self.hass.async_add_executor_job(
            partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
        )

    async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
        """Play a tone as feedback to the user if it's enabled."""
        if (self._tones & tone) != tone:
            return  # not enabled

        try:
            if tone not in self._tone_bytes:
                # Do I/O in executor
                self._tone_bytes[tone] = await self.hass.async_add_executor_job(
                    self._load_pcm,
                    _TONE_FILENAMES[tone],
                )

            await self._async_send_audio(
                self._tone_bytes[tone],
                silence_before=silence_before,
            )
        finally:
            # Always release _send_tts, even when loading or sending the tone
            # failed; otherwise the TTS response would wait forever.
            if tone == Tones.PROCESSING:
                self._processing_tone_done.set()

    def _load_pcm(self, file_name: str) -> bytes:
        """Load raw audio (16Khz, 16-bit mono)."""
        return (Path(__file__).parent / file_name).read_bytes()

View file

@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
self.async_on_remove(
self.voip_device.async_listen_update(self._is_active_changed)
)
@callback
def _is_active_changed(self, device: VoIPDevice) -> None:
"""Call when active state changed."""
self._attr_is_on = self._device.is_active
self._attr_is_on = self.voip_device.is_active
self.async_write_ha_state()

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from voip_utils import CallInfo
from voip_utils import CallInfo, VoipDatagramProtocol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Event, HomeAssistant, callback
@ -22,6 +22,7 @@ class VoIPDevice:
device_id: str
is_active: bool = False
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
protocol: VoipDatagramProtocol | None = None
@callback
def set_is_active(self, active: bool) -> None:
@ -56,6 +57,18 @@ class VoIPDevice:
return False
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
    """Return entity id for pipeline select.

    Looks up the select entity registered for this integration under the
    unique id ``<voip_id>-pipeline``.
    """
    unique_id = f"{self.voip_id}-pipeline"
    return er.async_get(hass).async_get_entity_id("select", DOMAIN, unique_id)
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
    """Return entity id for VAD sensitivity.

    Looks up the select entity registered for this integration under the
    unique id ``<voip_id>-vad_sensitivity``.
    """
    unique_id = f"{self.voip_id}-vad_sensitivity"
    return er.async_get(hass).async_get_entity_id("select", DOMAIN, unique_id)
class VoIPDevices:
"""Class to store devices."""

View file

@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
_attr_has_entity_name = True
_attr_should_poll = False
def __init__(self, device: VoIPDevice) -> None:
def __init__(self, voip_device: VoIPDevice) -> None:
"""Initialize VoIP entity."""
self._device = device
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
self.voip_device = voip_device
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, device.voip_id)},
identifiers={(DOMAIN, voip_device.voip_id)},
)

View file

@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline"],
"dependencies": ["assist_pipeline", "assist_satellite"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",

View file

@ -10,6 +10,16 @@
}
},
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
}
}
},
"binary_sensor": {
"call_in_progress": {
"name": "Call in progress"

View file

@ -3,15 +3,11 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial
import io
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING
import wave
from voip_utils import (
CallInfo,
@ -21,33 +17,19 @@ from voip_utils import (
VoipDatagramProtocol,
)
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import (
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_get_pipeline,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.audio_enhancer import (
AudioEnhancer,
MicroVadEnhancer,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid_now
from homeassistant.core import HomeAssistant
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
from .devices import VoIPDevices
_LOGGER = logging.getLogger(__name__)
@ -60,11 +42,8 @@ def make_protocol(
) -> VoipDatagramProtocol:
"""Plays a pre-recorded message if pipeline is misconfigured."""
voip_device = devices.async_get_or_create(call_info)
pipeline_id = pipeline_select.get_chosen_pipeline(
hass,
DOMAIN,
voip_device.voip_id,
)
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
@ -83,22 +62,18 @@ def make_protocol(
rtcp_state=rtcp_state,
)
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
if (protocol := voip_device.protocol) is None:
raise ValueError("VoIP satellite not found")
# Pipeline is properly configured
return PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
rtcp_state=rtcp_state,
)
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol.rtcp_state = rtcp_state
if protocol.rtcp_state is not None:
# Automatically disconnect when BYE is received over RTCP
protocol.rtcp_state.bye_callback = protocol.disconnect
return protocol
class HassVoipDatagramProtocol(VoipDatagramProtocol):
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
await self._closed_event.wait()
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
"""Run a voice assistant pipeline in a loop for a VoIP call."""
def __init__(
self,
hass: HomeAssistant,
language: str,
voip_device: VoIPDevice,
context: Context,
opus_payload_type: int,
pipeline_timeout: float = 30.0,
audio_timeout: float = 2.0,
buffered_chunks_before_speech: int = 100,
listening_tone_enabled: bool = True,
processing_tone_enabled: bool = True,
error_tone_enabled: bool = True,
tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0,
silence_seconds: float = 1.0,
rtcp_state: RtcpState | None = None,
) -> None:
"""Set up pipeline RTP server."""
super().__init__(
rate=RATE,
width=WIDTH,
channels=CHANNELS,
opus_payload_type=opus_payload_type,
rtcp_state=rtcp_state,
)
self.hass = hass
self.language = language
self.voip_device = voip_device
self.pipeline: Pipeline | None = None
self.pipeline_timeout = pipeline_timeout
self.audio_timeout = audio_timeout
self.buffered_chunks_before_speech = buffered_chunks_before_speech
self.listening_tone_enabled = listening_tone_enabled
self.processing_tone_enabled = processing_tone_enabled
self.error_tone_enabled = error_tone_enabled
self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout
self.silence_seconds = silence_seconds
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context
self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
self._tts_done = asyncio.Event()
self._session_id: str | None = None
self._tone_bytes: bytes | None = None
self._processing_bytes: bytes | None = None
self._error_bytes: bytes | None = None
self._pipeline_error: bool = False
def connection_made(self, transport):
"""Server is ready."""
super().connection_made(transport)
self.voip_device.set_is_active(True)
def connection_lost(self, exc):
"""Handle connection is lost or closed."""
super().connection_lost(exc)
self.voip_device.set_is_active(False)
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
self._clear_audio_queue()
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.hass.async_create_background_task(
self._run_pipeline(),
"voip_pipeline_run",
)
self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
self,
) -> None:
"""Forward audio to pipeline STT and handle TTS."""
if self._session_id is None:
self._session_id = ulid_now()
# Play listening tone at the start of each cycle
if self.listening_tone_enabled:
await self._play_listening_tone()
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
audio_enhancer = MicroVadEnhancer(0, 0, True)
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)
speech_detected = await self._wait_for_speech(
segmenter,
audio_enhancer,
chunk_buffer,
)
if not speech_detected:
_LOGGER.debug("No speech detected")
return
_LOGGER.debug("Starting pipeline")
self._tts_done.clear()
async def stt_stream():
try:
async for chunk in self._segment_audio(
segmenter,
audio_enhancer,
chunk_buffer,
):
yield chunk
if self.processing_tone_enabled:
await self._play_processing_tone()
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Audio timeout")
self._session_id = None
self.disconnect()
finally:
self._clear_audio_queue()
# Run pipeline with a timeout
async with asyncio.timeout(self.pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
tts_audio_output="wav",
)
if self._pipeline_error:
self._pipeline_error = False
if self.error_tone_enabled:
await self._play_error_tone()
else:
# Block until TTS is done speaking.
#
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Pipeline timeout")
self._session_id = None
self.disconnect()
finally:
# Allow pipeline to run again
self._pipeline_task = None
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: MutableSequence[bytes],
):
"""Buffer audio chunks until speech is detected.
Returns True if speech was detected, False otherwise.
"""
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
chunk_buffer.append(chunk)
segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
)
if segmenter.in_command:
# Buffer until command starts
if len(vad_buffer) > 0:
chunk_buffer.append(vad_buffer.bytes())
return True
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
return False
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished."""
# Buffered chunks first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
if not segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
):
# Voice command is finished
break
yield chunk
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
def _clear_audio_queue(self) -> None:
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def _event_callback(self, event: PipelineEvent):
if not event.data:
return
if event.type == PipelineEventType.INTENT_END:
# Capture conversation id
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
tts_output = event.data["tts_output"]
if tts_output:
media_id = tts_output["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id),
"voip_pipeline_tts",
)
else:
# Empty TTS response
self._tts_done.set()
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS
self._pipeline_error = True
async def _send_tts(self, media_id: str) -> None:
    """Fetch TTS audio for ``media_id`` and send it to the caller via RTP.

    Does nothing when there is no transport. Raises ValueError when the
    media is not WAV or its rate/width/channels differ from
    RATE/WIDTH/CHANNELS. A timeout while sending is logged and re-raised.
    ``self._tts_done`` is set on every exit path.
    """
    try:
        if self.transport is None:
            return

        extension, data = await tts.async_get_media_source_audio(
            self.hass, media_id
        )
        if extension != "wav":
            raise ValueError(f"Only WAV audio can be streamed, got {extension}")

        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            sample_width = wav_file.getsampwidth()
            sample_channels = wav_file.getnchannels()

            format_matches = (
                sample_rate == RATE
                and sample_width == WIDTH
                and sample_channels == CHANNELS
            )
            if not format_matches:
                raise ValueError(
                    f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                    f" got {sample_rate}/{sample_width}/{sample_channels}"
                )

            frame_count = wav_file.getnframes()
            audio_bytes = wav_file.readframes(frame_count)

        _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

        # Allow the audio's own duration plus the configured extra time
        tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
        tts_seconds = tts_samples / RATE
        send_deadline = tts_seconds + self.tts_extra_timeout

        async with asyncio.timeout(send_deadline):
            # TTS audio is 16Khz 16-bit mono
            await self._async_send_audio(audio_bytes)
    except TimeoutError:
        _LOGGER.warning("TTS timeout")
        raise
    finally:
        # Let the pipeline know it can restart
        self._tts_done.set()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
    """Run ``send_audio`` in the executor.

    ``RTP_AUDIO_SETTINGS`` and any extra keyword arguments are forwarded
    to ``send_audio``.
    """
    blocking_send = partial(
        self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs
    )
    await self.hass.async_add_executor_job(blocking_send)
async def _play_listening_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is listening."""
if self._tone_bytes is None:
# Do I/O in executor
self._tone_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"tone.pcm",
)
await self._async_send_audio(
self._tone_bytes,
silence_before=self.tone_delay,
)
async def _play_processing_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is processing the voice command."""
if self._processing_bytes is None:
# Do I/O in executor
self._processing_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"processing.pcm",
)
await self._async_send_audio(self._processing_bytes)
async def _play_error_tone(self) -> None:
"""Play a tone to indicate a pipeline error occurred."""
if self._error_bytes is None:
# Do I/O in executor
self._error_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"error.pcm",
)
await self._async_send_audio(self._error_bytes)
def _load_pcm(self, file_name: str) -> bytes:
"""Load raw audio (16Khz, 16-bit mono)."""
return (Path(__file__).parent / file_name).read_bytes()
class PreRecordMessageProtocol(RtpDatagramProtocol):
"""Plays a pre-recorded message on a loop."""

View file

@ -41,6 +41,7 @@ class Platform(StrEnum):
AIR_QUALITY = "air_quality"
ALARM_CONTROL_PANEL = "alarm_control_panel"
ASSIST_SATELLITE = "assist_satellite"
BINARY_SENSOR = "binary_sensor"
BUTTON = "button"
CALENDAR = "calendar"

View file

@ -5,22 +5,28 @@ from __future__ import annotations
from asyncio import (
AbstractEventLoop,
Future,
Queue,
Semaphore,
Task,
TimerHandle,
gather,
get_running_loop,
timeout as async_timeout,
)
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
import concurrent.futures
import logging
import threading
from typing import Any
from typing_extensions import TypeVar
# Logger named after this module
_LOGGER = logging.getLogger(__name__)
# String constant; its consumers are not visible in this part of the file
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
# Type variable with a default of Any (default= comes from typing_extensions)
_DataT = TypeVar("_DataT", default=Any)
def create_eager_task[_T](
coro: Coroutine[Any, Any, _T],
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
"""Return a list of scheduled TimerHandles."""
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
return handles
async def queue_to_iterable(
queue: Queue[_DataT], timeout: float | None = None
) -> AsyncIterable[_DataT]:
"""Stream items from a queue until None with an optional timeout per item."""
if timeout is None:
while (item := await queue.get()) is not None:
yield item
else:
async with async_timeout(timeout):
item = await queue.get()
while item is not None:
yield item
async with async_timeout(timeout):
item = await queue.get()

File diff suppressed because one or more lines are too long

View file

@ -3,15 +3,26 @@
import asyncio
import io
from pathlib import Path
import time
from unittest.mock import AsyncMock, Mock, patch
import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from voip_utils import CallInfo
from homeassistant.components import assist_pipeline, voip
from homeassistant.components.voip.devices import VoIPDevice
from homeassistant.components import assist_pipeline, assist_satellite, voip
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteState,
)
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
from homeassistant.const import STATE_OFF, STATE_ON, Platform
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import async_setup_component
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
return wav_io.getvalue()
def async_get_satellite_entity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> AssistSatelliteEntity | None:
    """Return the Assist satellite entity for a unique id prefix.

    The entity registry is searched for "<unique_id_prefix>-assist_satellite"
    on the assist_satellite platform. None is returned when no entity id is
    registered for it.
    """
    unique_id = f"{unique_id_prefix}-assist_satellite"
    entity_id = er.async_get(hass).async_get_entity_id(
        Platform.ASSIST_SATELLITE, domain, unique_id
    )
    if entity_id is None:
        return None

    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[
        assist_satellite.DOMAIN
    ]
    return satellites.get_entity(entity_id)
async def test_is_valid_call(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
) -> None:
    """Test that a call is not allowed until the allow-call switch is on."""
    assert await async_setup_component(hass, "voip", {})

    protocol = HassVoipDatagramProtocol(hass, voip_devices)
    # Rejected while the switch is still off
    assert not protocol.is_valid_call(call_info)

    switch_entity_id = er.async_get(hass).async_get_entity_id(
        "switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
    )
    assert switch_entity_id is not None

    switch_state = hass.states.get(switch_entity_id)
    assert switch_state is not None
    assert switch_state.state == STATE_OFF

    # Turn the switch on; the same call is now accepted
    hass.states.async_set(switch_entity_id, STATE_ON)
    assert protocol.is_valid_call(call_info)
async def test_calls_not_allowed(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the "problem" recording is played when calls aren't allowed."""
    assert await async_setup_component(hass, "voip", {})

    protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"

    # Capture whatever the protocol tries to send
    playback_finished = asyncio.Event()
    captured: list[bytes] = [b""]

    def send_audio(audio_bytes: bytes, **kwargs):
        # Expected to be problem.pcm from components/voip
        captured[0] = audio_bytes
        playback_finished.set()

    protocol.transport = Mock()
    protocol.loop_delay = 0

    with patch.object(protocol, "send_audio", send_audio):
        protocol.on_chunk(bytes(_ONE_SECOND))

        async with asyncio.timeout(1):
            await playback_finished.wait()

    played_audio_bytes = captured[0]
    assert sum(played_audio_bytes) > 0
    assert played_audio_bytes == snapshot()
async def test_pipeline_not_found(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the "problem" recording is used when no pipeline is found."""
    assert await async_setup_component(hass, "voip", {})

    # Make the pipeline lookup come back empty
    no_pipeline = patch(
        "homeassistant.components.voip.voip.async_get_pipeline", return_value=None
    )
    with no_pipeline:
        protocol: PreRecordMessageProtocol = make_protocol(
            hass, voip_devices, call_info
        )

        assert isinstance(protocol, PreRecordMessageProtocol)
        assert protocol.file_name == "problem.pcm"
async def test_satellite_prepared(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that make_protocol returns the satellite when a pipeline exists."""
    assert await async_setup_component(hass, "voip", {})

    pipeline = assist_pipeline.Pipeline(
        name="test",
        language="en",
        conversation_engine="test",
        conversation_language="en",
        stt_engine="test",
        stt_language="en",
        tts_engine="test",
        tts_language="en",
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    # Pipeline lookup returns the pipeline built above
    with patch(
        "homeassistant.components.voip.voip.async_get_pipeline",
        return_value=pipeline,
    ):
        protocol = make_protocol(hass, voip_devices, call_info)
        assert protocol == satellite
async def test_pipeline(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
) -> None:
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
voip_user_id = satellite.config_entry.data["user"]
assert voip_user_id
return 0
# Satellite is muted until a call begins
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
done = asyncio.Event()
# Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
async def async_pipeline_from_audio_stream(
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
):
assert context.user_id == voip_user_id
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
in_command = False
async for chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
assert _chunk != bad_chunk
assert chunk != bad_chunk
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Test empty data
event_callback(
@ -71,6 +229,38 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_START,
data={"engine": "test", "metadata": {}},
)
)
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_START,
data={
"engine": "test",
"language": hass.config.language,
"intent_input": "fake-text",
"conversation_id": None,
"device_id": None,
},
)
)
assert satellite.state == AssistSatelliteState.PROCESSING
# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
@ -83,6 +273,21 @@ async def test_pipeline(
)
)
# Fake tts result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_START,
data={
"engine": "test",
"language": hass.config.language,
"voice": "test",
"tts_input": "fake-text",
},
)
)
assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
@ -91,6 +296,18 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.RUN_END
)
)
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
@ -100,102 +317,56 @@ async def test_pipeline(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
)
rtp_protocol.transport = Mock()
satellite._tones = Tones(0)
satellite.transport = Mock()
satellite.connection_made(satellite.transport)
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
# Ensure audio queue is cleared before pipeline starts
rtp_protocol._audio_queue.put_nowait(bad_chunk)
satellite._audio_queue.put_nowait(bad_chunk)
def send_audio(*args, **kwargs):
# Test finished successfully
done.set()
# Don't send audio
pass
rtp_protocol.send_audio = Mock(side_effect=send_audio)
satellite.send_audio = Mock(side_effect=send_audio)
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
"""Test timeout during pipeline run."""
assert await async_setup_component(hass, "voip", {})
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
await asyncio.sleep(10)
with (
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
return_value=True,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
pipeline_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
await done.wait()
# Finished speaking
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
async def test_stt_stream_timeout(
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
) -> None:
"""Test timeout in STT stream during pipeline run."""
assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
pass
with patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
audio_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
satellite._tones = Tones(0)
satellite._audio_chunk_timeout = 0.001
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
satellite.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
async def test_tts_timeout(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -278,15 +448,7 @@ async def test_tts_timeout(
tone_bytes = bytes([1, 2, 3, 4])
def send_audio(audio_bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
# Block here to force a timeout in _send_tts
time.sleep(2)
async def async_send_audio(audio_bytes, **kwargs):
async def async_send_audio(audio_bytes: bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
@ -303,37 +465,22 @@ async def test_tts_timeout(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
tts_extra_timeout=0.001,
listening_tone_enabled=True,
processing_tone_enabled=True,
error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
)
rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes
rtp_protocol._error_bytes = tone_bytes
rtp_protocol.transport = Mock()
rtp_protocol.send_audio = Mock()
satellite._tts_extra_timeout = 0.001
for tone in Tones:
satellite._tone_bytes[tone] = tone_bytes
original_send_tts = rtp_protocol._send_tts
satellite.transport = Mock()
satellite.send_audio = Mock()
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -342,17 +489,17 @@ async def test_tts_timeout(
done.set()
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -361,26 +508,34 @@ async def test_tts_timeout(
async def test_tts_wrong_extension(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
async def test_empty_tts_output(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will not stream when output is empty."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -605,37 +754,78 @@ async def test_empty_tts_output(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to finish
async with asyncio.timeout(1):
await rtp_protocol._tts_done.wait()
await satellite._tts_done.wait()
mock_send_tts.assert_not_called()
async def test_pipeline_error(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that audio is played after the pipeline reports an error."""
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    tone_played = asyncio.Event()
    captured: list[bytes] = [b""]

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        # Report an error straight away instead of running a pipeline
        error_event = assist_pipeline.PipelineEvent(
            type=assist_pipeline.PipelineEventType.ERROR,
            data={"code": "error-code", "message": "error message"},
        )
        kwargs["event_callback"](error_event)

    async def async_send_audio(audio_bytes: bytes, **kwargs):
        # Expected to be error.pcm from components/voip
        captured[0] = audio_bytes
        tone_played.set()

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        satellite._tones = Tones.ERROR
        satellite.transport = Mock()
        satellite._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]

        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait until the audio has been handed to the sender
        async with asyncio.timeout(1):
            await tone_played.wait()

        played_audio_bytes = captured[0]
        assert sum(played_audio_bytes) > 0
        assert played_audio_bytes == snapshot()

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -3,8 +3,8 @@
from unittest.mock import Mock, patch
from homeassistant.components import assist_pipeline
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.pipeline import PipelineData
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
timer_handle.cancel()
timer_handle2.cancel()
timer_handle3.cancel()
async def test_queue_to_iterable() -> None:
    """Test queue_to_iterable with and without a per-item timeout."""
    queue: asyncio.Queue[int | None] = asyncio.Queue()
    expected_items = list(range(10))

    for number in expected_items:
        await queue.put(number)

    # None ends the stream
    await queue.put(None)

    actual_items = []
    async for received in hasync.queue_to_iterable(queue):
        actual_items.append(received)

    assert expected_items == actual_items

    # Nothing queued: the first wait must time out
    assert queue.empty()

    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # Should time out very quickly
            async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
                await asyncio.sleep(1)

    # One item queued: the wait for the second item must time out
    assert queue.empty()
    await queue.put(12345)

    async with asyncio.timeout(1):
        with pytest.raises(asyncio.TimeoutError):  # noqa: PT012
            # Should time out very quickly
            async for item in hasync.queue_to_iterable(queue, timeout=0.01):
                if item != 12345:
                    await asyncio.sleep(1)

    assert queue.empty()