Add Assist satellite entity + VoIP (#123830)
* Add assist_satellite and implement VoIP * Fix tests * More tests * Improve test * Update entity state * Set state correctly * Move more functionality into base class * Move RTP protocol into entity * Fix tests * Remove string * Move to util method * Align states better with pipeline events * Remove public async_get_satellite_entity * WAITING_FOR_WAKE_WORD * Pass entity ids for pipeline/vad sensitivity * Remove connect/disconnect * Clean up * Final cleanup
This commit is contained in:
parent
36bfd7b9ce
commit
644427ecc7
26 changed files with 1089 additions and 647 deletions
|
@ -143,6 +143,7 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/aseko_pool_live/ @milanmeu
|
||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
|
|
@ -16,6 +16,7 @@ from .const import (
|
|||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
OPTION_PREFERRED,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
SAMPLE_WIDTH,
|
||||
|
@ -57,6 +58,7 @@ __all__ = (
|
|||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
"OPTION_PREFERRED",
|
||||
"SAMPLES_PER_CHUNK",
|
||||
"SAMPLE_RATE",
|
||||
"SAMPLE_WIDTH",
|
||||
|
|
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
|||
MS_PER_CHUNK = 10
|
||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
|
|
@ -504,7 +504,7 @@ class AudioSettings:
|
|||
is_vad_enabled: bool = True
|
||||
"""True if VAD is used to determine the end of the voice command."""
|
||||
|
||||
silence_seconds: float = 0.5
|
||||
silence_seconds: float = 0.7
|
||||
"""Seconds of silence after voice command has ended."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -906,6 +906,8 @@ class PipelineRun:
|
|||
metadata,
|
||||
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||
)
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
raise # expected
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||
raise SpeechToTextError(
|
||||
|
|
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
|
|
44
homeassistant/components/assist_satellite/__init__.py
Normal file
44
homeassistant/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
"""Base class for assist satellite entities."""
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
from .models import AssistSatelliteState
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteState",
|
||||
"AssistSatelliteEntity",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
3
homeassistant/components/assist_satellite/const.py
Normal file
3
homeassistant/components/assist_satellite/const.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Constants for assist satellite."""
|
||||
|
||||
DOMAIN = "assist_satellite"
|
151
homeassistant/components/assist_satellite/entity.py
Normal file
151
homeassistant/components/assist_satellite/entity.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
"""Assist satellite entity."""
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
import time
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipelines,
|
||||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .models import AssistSatelliteState
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description = EntityDescription(
|
||||
key="assist_satellite",
|
||||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_should_poll = False
|
||||
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_run_has_tts: bool = False
|
||||
|
||||
async def _async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
pipeline_entity_id: str | None = None,
|
||||
vad_sensitivity_entity_id: str | None = None,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||
pipeline_id: str | None = None
|
||||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
|
||||
if pipeline_entity_id:
|
||||
# Resolve pipeline by name
|
||||
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
|
||||
if (pipeline_entity_state is not None) and (
|
||||
pipeline_entity_state.state != OPTION_PREFERRED
|
||||
):
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
pipeline_id = pipeline.id
|
||||
break
|
||||
|
||||
if vad_sensitivity_entity_id:
|
||||
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
|
||||
if vad_sensitivity_state is not None:
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
device_id: str | None = None
|
||||
if self.registry_entry is not None:
|
||||
device_id = self.registry_entry.device_id
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
(self._context is None)
|
||||
or (self._context_set is None)
|
||||
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||
):
|
||||
self.async_set_context(Context())
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = ulid.ulid()
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._run_has_tts = False
|
||||
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self.on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=pipeline_id,
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||
),
|
||||
)
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type == PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type == PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type == PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type == PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type == PipelineEventType.RUN_END:
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
def _set_state(self, state: AssistSatelliteState):
|
||||
"""Set the entity's state."""
|
||||
self._attr_state = state
|
||||
self.async_write_ha_state()
|
||||
|
||||
def tts_response_finished(self) -> None:
|
||||
"""Tell entity that the text-to-speech response has finished playing."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "assist_satellite",
|
||||
"name": "Assist Satellite",
|
||||
"codeowners": ["@synesthesiam"],
|
||||
"config_flow": false,
|
||||
"dependencies": ["assist_pipeline", "stt"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity"
|
||||
}
|
19
homeassistant/components/assist_satellite/models.py
Normal file
19
homeassistant/components/assist_satellite/models.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
"""Models for assist satellite."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
||||
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||
|
||||
LISTENING_COMMAND = "listening_command"
|
||||
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||
|
||||
PROCESSING = "processing"
|
||||
"""Home Assistant is processing the voice command."""
|
||||
|
||||
RESPONDING = "responding"
|
||||
"""Device is speaking the response."""
|
14
homeassistant/components/assist_satellite/strings.json
Normal file
14
homeassistant/components/assist_satellite/strings.json
Normal file
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ from .devices import VoIPDevices
|
|||
from .voip import HassVoipDatagramProtocol
|
||||
|
||||
PLATFORMS = (
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
|
298
homeassistant/components/voip/assist_satellite.py
Normal file
298
homeassistant/components/voip/assist_satellite.py
Normal file
|
@ -0,0 +1,298 @@
|
|||
"""Assist satellite entity for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import IntFlag
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
import wave
|
||||
|
||||
from voip_utils import RtpDatagramProtocol
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.async_ import queue_to_iterable
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||
|
||||
|
||||
class Tones(IntFlag):
|
||||
"""Feedback tones for specific events."""
|
||||
|
||||
LISTENING = 1
|
||||
PROCESSING = 2
|
||||
ERROR = 4
|
||||
|
||||
|
||||
_TONE_FILENAMES: dict[Tones, str] = {
|
||||
Tones.LISTENING: "tone.pcm",
|
||||
Tones.PROCESSING: "processing.pcm",
|
||||
Tones.ERROR: "error.pcm",
|
||||
}
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP Assist satellite entity."""
|
||||
domain_data: DomainData = hass.data[DOMAIN]
|
||||
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
entities: list[VoIPEntity] = [
|
||||
VoipAssistSatellite(hass, device, config_entry)
|
||||
for device in domain_data.devices
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||
"""Assist satellite for VoIP devices."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
voip_device: VoIPDevice,
|
||||
config_entry: ConfigEntry,
|
||||
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
|
||||
) -> None:
|
||||
"""Initialize an Assist satellite."""
|
||||
VoIPEntity.__init__(self, voip_device)
|
||||
AssistSatelliteEntity.__init__(self)
|
||||
RtpDatagramProtocol.__init__(self)
|
||||
|
||||
self.config_entry = config_entry
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._audio_chunk_timeout: float = 2.0
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._pipeline_had_error: bool = False
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_extra_timeout: float = 1.0
|
||||
self._tone_bytes: dict[Tones, bytes] = {}
|
||||
self._tones = tones
|
||||
self._processing_tone_done = asyncio.Event()
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
self.voip_device.protocol = self
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
if self.voip_device.protocol == self:
|
||||
self.voip_device.protocol = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||
|
||||
try:
|
||||
self._tts_done.clear()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
|
||||
audio_stream=queue_to_iterable(
|
||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||
),
|
||||
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
|
||||
self.hass
|
||||
),
|
||||
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
|
||||
self.hass
|
||||
),
|
||||
)
|
||||
|
||||
if self._pipeline_had_error:
|
||||
self._pipeline_had_error = False
|
||||
await self._play_tone(Tones.ERROR)
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||
self.disconnect()
|
||||
self._clear_audio_queue()
|
||||
finally:
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
super().on_pipeline_event(event)
|
||||
|
||||
if event.type == PipelineEventType.STT_END:
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
self._processing_tone_done.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
|
||||
)
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||
self._pipeline_had_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return # not connected
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
# Don't overlap TTS and processing beep
|
||||
await self._processing_tone_done.wait()
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
# Update satellite state
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
|
||||
"""Play a tone as feedback to the user if it's enabled."""
|
||||
if (self._tones & tone) != tone:
|
||||
return # not enabled
|
||||
|
||||
if tone not in self._tone_bytes:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
_TONE_FILENAMES[tone],
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes[tone],
|
||||
silence_before=silence_before,
|
||||
)
|
||||
|
||||
if tone == Tones.PROCESSING:
|
||||
self._processing_tone_done.set()
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
|
@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
|
|||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
|
||||
self.async_on_remove(
|
||||
self.voip_device.async_listen_update(self._is_active_changed)
|
||||
)
|
||||
|
||||
@callback
|
||||
def _is_active_changed(self, device: VoIPDevice) -> None:
|
||||
"""Call when active state changed."""
|
||||
self._attr_is_on = self._device.is_active
|
||||
self._attr_is_on = self.voip_device.is_active
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voip_utils import CallInfo
|
||||
from voip_utils import CallInfo, VoipDatagramProtocol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
@ -22,6 +22,7 @@ class VoIPDevice:
|
|||
device_id: str
|
||||
is_active: bool = False
|
||||
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
||||
protocol: VoipDatagramProtocol | None = None
|
||||
|
||||
@callback
|
||||
def set_is_active(self, active: bool) -> None:
|
||||
|
@ -56,6 +57,18 @@ class VoIPDevice:
|
|||
|
||||
return False
|
||||
|
||||
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for pipeline select."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
|
||||
|
||||
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for VAD sensitivity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
|
||||
)
|
||||
|
||||
|
||||
class VoIPDevices:
|
||||
"""Class to store devices."""
|
||||
|
|
|
@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
|
|||
_attr_has_entity_name = True
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, device: VoIPDevice) -> None:
|
||||
def __init__(self, voip_device: VoIPDevice) -> None:
|
||||
"""Initialize VoIP entity."""
|
||||
self._device = device
|
||||
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
|
||||
self.voip_device = voip_device
|
||||
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, device.voip_id)},
|
||||
identifiers={(DOMAIN, voip_device.voip_id)},
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline"],
|
||||
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
|
|
@ -10,6 +10,16 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"call_in_progress": {
|
||||
"name": "Call in progress"
|
||||
|
|
|
@ -3,15 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import wave
|
||||
|
||||
from voip_utils import (
|
||||
CallInfo,
|
||||
|
@ -21,33 +17,19 @@ from voip_utils import (
|
|||
VoipDatagramProtocol,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_get_pipeline,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.audio_enhancer import (
|
||||
AudioEnhancer,
|
||||
MicroVadEnhancer,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import (
|
||||
AudioBuffer,
|
||||
VadSensitivity,
|
||||
VoiceCommandSegmenter,
|
||||
)
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
from .devices import VoIPDevices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,11 +42,8 @@ def make_protocol(
|
|||
) -> VoipDatagramProtocol:
|
||||
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
||||
voip_device = devices.async_get_or_create(call_info)
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
|
||||
try:
|
||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||
except PipelineNotFound:
|
||||
|
@ -83,22 +62,18 @@ def make_protocol(
|
|||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
if (protocol := voip_device.protocol) is None:
|
||||
raise ValueError("VoIP satellite not found")
|
||||
|
||||
# Pipeline is properly configured
|
||||
return PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(user_id=devices.config_entry.data["user"]),
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
|
||||
protocol.rtcp_state = rtcp_state
|
||||
if protocol.rtcp_state is not None:
|
||||
# Automatically disconnect when BYE is received over RTCP
|
||||
protocol.rtcp_state.bye_callback = protocol.disconnect
|
||||
|
||||
return protocol
|
||||
|
||||
|
||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||
|
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
|||
await self._closed_event.wait()
|
||||
|
||||
|
||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
language: str,
|
||||
voip_device: VoIPDevice,
|
||||
context: Context,
|
||||
opus_payload_type: int,
|
||||
pipeline_timeout: float = 30.0,
|
||||
audio_timeout: float = 2.0,
|
||||
buffered_chunks_before_speech: int = 100,
|
||||
listening_tone_enabled: bool = True,
|
||||
processing_tone_enabled: bool = True,
|
||||
error_tone_enabled: bool = True,
|
||||
tone_delay: float = 0.2,
|
||||
tts_extra_timeout: float = 1.0,
|
||||
silence_seconds: float = 1.0,
|
||||
rtcp_state: RtcpState | None = None,
|
||||
) -> None:
|
||||
"""Set up pipeline RTP server."""
|
||||
super().__init__(
|
||||
rate=RATE,
|
||||
width=WIDTH,
|
||||
channels=CHANNELS,
|
||||
opus_payload_type=opus_payload_type,
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
self.hass = hass
|
||||
self.language = language
|
||||
self.voip_device = voip_device
|
||||
self.pipeline: Pipeline | None = None
|
||||
self.pipeline_timeout = pipeline_timeout
|
||||
self.audio_timeout = audio_timeout
|
||||
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
||||
self.listening_tone_enabled = listening_tone_enabled
|
||||
self.processing_tone_enabled = processing_tone_enabled
|
||||
self.error_tone_enabled = error_tone_enabled
|
||||
self.tone_delay = tone_delay
|
||||
self.tts_extra_timeout = tts_extra_timeout
|
||||
self.silence_seconds = silence_seconds
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._context = context
|
||||
self._conversation_id: str | None = None
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._tts_done = asyncio.Event()
|
||||
self._session_id: str | None = None
|
||||
self._tone_bytes: bytes | None = None
|
||||
self._processing_bytes: bytes | None = None
|
||||
self._error_bytes: bytes | None = None
|
||||
self._pipeline_error: bool = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Server is ready."""
|
||||
super().connection_made(transport)
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Handle connection is lost or closed."""
|
||||
super().connection_lost(exc)
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.hass.async_create_background_task(
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
if self._session_id is None:
|
||||
self._session_id = ulid_now()
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
if self.listening_tone_enabled:
|
||||
await self._play_listening_tone()
|
||||
|
||||
try:
|
||||
# Wait for speech before starting pipeline
|
||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||
audio_enhancer = MicroVadEnhancer(0, 0, True)
|
||||
chunk_buffer: deque[bytes] = deque(
|
||||
maxlen=self.buffered_chunks_before_speech,
|
||||
)
|
||||
speech_detected = await self._wait_for_speech(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
)
|
||||
if not speech_detected:
|
||||
_LOGGER.debug("No speech detected")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
self._tts_done.clear()
|
||||
|
||||
async def stt_stream():
|
||||
try:
|
||||
async for chunk in self._segment_audio(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if self.processing_tone_enabled:
|
||||
await self._play_processing_tone()
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Audio timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
async with asyncio.timeout(self.pipeline_timeout):
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.voip_device.voip_id
|
||||
),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=self.voip_device.device_id,
|
||||
tts_audio_output="wav",
|
||||
)
|
||||
|
||||
if self._pipeline_error:
|
||||
self._pipeline_error = False
|
||||
if self.error_tone_enabled:
|
||||
await self._play_error_tone()
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
async def _wait_for_speech(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: MutableSequence[bytes],
|
||||
):
|
||||
"""Buffer audio chunks until speech is detected.
|
||||
|
||||
Returns True if speech was detected, False otherwise.
|
||||
"""
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
chunk_buffer.append(chunk)
|
||||
|
||||
segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
)
|
||||
if segmenter.in_command:
|
||||
# Buffer until command starts
|
||||
if len(vad_buffer) > 0:
|
||||
chunk_buffer.append(vad_buffer.bytes())
|
||||
|
||||
return True
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
return False
|
||||
|
||||
async def _segment_audio(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: Sequence[bytes],
|
||||
) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks until voice command has finished."""
|
||||
# Buffered chunks first
|
||||
for buffered_chunk in chunk_buffer:
|
||||
yield buffered_chunk
|
||||
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
if not segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def _event_callback(self, event: PipelineEvent):
|
||||
if not event.data:
|
||||
return
|
||||
|
||||
if event.type == PipelineEventType.INTENT_END:
|
||||
# Capture conversation id
|
||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS
|
||||
self._pipeline_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_listening_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is listening."""
|
||||
if self._tone_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"tone.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes,
|
||||
silence_before=self.tone_delay,
|
||||
)
|
||||
|
||||
async def _play_processing_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
||||
if self._processing_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._processing_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"processing.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._processing_bytes)
|
||||
|
||||
async def _play_error_tone(self) -> None:
|
||||
"""Play a tone to indicate a pipeline error occurred."""
|
||||
if self._error_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._error_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"error.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._error_bytes)
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
||||
|
||||
|
||||
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
||||
"""Plays a pre-recorded message on a loop."""
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
|||
|
||||
AIR_QUALITY = "air_quality"
|
||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||
ASSIST_SATELLITE = "assist_satellite"
|
||||
BINARY_SENSOR = "binary_sensor"
|
||||
BUTTON = "button"
|
||||
CALENDAR = "calendar"
|
||||
|
|
|
@ -5,22 +5,28 @@ from __future__ import annotations
|
|||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
Future,
|
||||
Queue,
|
||||
Semaphore,
|
||||
Task,
|
||||
TimerHandle,
|
||||
gather,
|
||||
get_running_loop,
|
||||
timeout as async_timeout,
|
||||
)
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
||||
|
||||
_DataT = TypeVar("_DataT", default=Any)
|
||||
|
||||
|
||||
def create_eager_task[_T](
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
|
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
|
|||
"""Return a list of scheduled TimerHandles."""
|
||||
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
||||
return handles
|
||||
|
||||
|
||||
async def queue_to_iterable(
|
||||
queue: Queue[_DataT], timeout: float | None = None
|
||||
) -> AsyncIterable[_DataT]:
|
||||
"""Stream items from a queue until None with an optional timeout per item."""
|
||||
if timeout is None:
|
||||
while (item := await queue.get()) is not None:
|
||||
yield item
|
||||
else:
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
||||
while item is not None:
|
||||
yield item
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
|
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
File diff suppressed because one or more lines are too long
|
@ -3,15 +3,26 @@
|
|||
import asyncio
|
||||
import io
|
||||
from pathlib import Path
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from voip_utils import CallInfo
|
||||
|
||||
from homeassistant.components import assist_pipeline, voip
|
||||
from homeassistant.components.voip.devices import VoIPDevice
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, voip
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
|
|||
return wav_io.getvalue()
|
||||
|
||||
|
||||
def async_get_satellite_entity(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> AssistSatelliteEntity | None:
|
||||
"""Get Assist satellite entity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
|
||||
)
|
||||
if satellite_entity_id is None:
|
||||
return None
|
||||
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||
assist_satellite.DOMAIN
|
||||
]
|
||||
return component.get_entity(satellite_entity_id)
|
||||
|
||||
|
||||
async def test_is_valid_call(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that a call is now allowed from an unknown device."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol = HassVoipDatagramProtocol(hass, voip_devices)
|
||||
assert not protocol.is_valid_call(call_info)
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
allowed_call_entity_id = ent_reg.async_get_entity_id(
|
||||
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
|
||||
)
|
||||
assert allowed_call_entity_id is not None
|
||||
state = hass.states.get(allowed_call_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
|
||||
# Allow calls
|
||||
hass.states.async_set(allowed_call_entity_id, STATE_ON)
|
||||
assert protocol.is_valid_call(call_info)
|
||||
|
||||
|
||||
async def test_calls_not_allowed(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when calls aren't allowed."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
# Test the playback
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
def send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be problem.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
protocol.transport = Mock()
|
||||
protocol.loop_delay = 0
|
||||
with patch.object(protocol, "send_audio", send_audio):
|
||||
protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
||||
|
||||
async def test_pipeline_not_found(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when a pipeline isn't found."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
|
||||
):
|
||||
protocol: PreRecordMessageProtocol = make_protocol(
|
||||
hass, voip_devices, call_info
|
||||
)
|
||||
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
|
||||
async def test_satellite_prepared(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that satellite is prepared for a call."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test",
|
||||
stt_engine="test",
|
||||
stt_language="en",
|
||||
tts_engine="test",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
protocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert protocol == satellite
|
||||
|
||||
|
||||
async def test_pipeline(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that pipeline function is called from RTP protocol."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
voip_user_id = satellite.config_entry.data["user"]
|
||||
assert voip_user_id
|
||||
|
||||
return 0
|
||||
# Satellite is muted until a call begins
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
# Used to test that audio queue is cleared before pipeline starts
|
||||
bad_chunk = bytes([1, 2, 3, 4])
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
|
||||
):
|
||||
assert context.user_id == voip_user_id
|
||||
assert device_id == voip_device.device_id
|
||||
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
assert _chunk != bad_chunk
|
||||
assert chunk != bad_chunk
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Test empty data
|
||||
event_callback(
|
||||
|
@ -71,6 +229,38 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_START,
|
||||
data={"engine": "test", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "fake-text",
|
||||
"conversation_id": None,
|
||||
"device_id": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -83,6 +273,21 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# Fake tts result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"voice": "test",
|
||||
"tts_input": "fake-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Proceed with media output
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -91,6 +296,18 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.RUN_END
|
||||
)
|
||||
)
|
||||
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
done.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
|
@ -100,102 +317,56 @@ async def test_pipeline(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite._tones = Tones(0)
|
||||
satellite.transport = Mock()
|
||||
|
||||
satellite.connection_made(satellite.transport)
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
# Ensure audio queue is cleared before pipeline starts
|
||||
rtp_protocol._audio_queue.put_nowait(bad_chunk)
|
||||
satellite._audio_queue.put_nowait(bad_chunk)
|
||||
|
||||
def send_audio(*args, **kwargs):
|
||||
# Test finished successfully
|
||||
done.set()
|
||||
# Don't send audio
|
||||
pass
|
||||
|
||||
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
||||
satellite.send_audio = Mock(side_effect=send_audio)
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes aggressive VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
"""Test timeout during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
pipeline_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
# Finished speaking
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
async def test_stt_stream_timeout(
|
||||
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
|
||||
) -> None:
|
||||
"""Test timeout in STT stream during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
|
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
pass
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
audio_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
satellite._tones = Tones(0)
|
||||
satellite._audio_chunk_timeout = 0.001
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
satellite.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
|
||||
async def test_tts_timeout(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will time out based on its length."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -278,15 +448,7 @@ async def test_tts_timeout(
|
|||
|
||||
tone_bytes = bytes([1, 2, 3, 4])
|
||||
|
||||
def send_audio(audio_bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
||||
# Block here to force a timeout in _send_tts
|
||||
time.sleep(2)
|
||||
|
||||
async def async_send_audio(audio_bytes, **kwargs):
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
@ -303,37 +465,22 @@ async def test_tts_timeout(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
tts_extra_timeout=0.001,
|
||||
listening_tone_enabled=True,
|
||||
processing_tone_enabled=True,
|
||||
error_tone_enabled=True,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
||||
)
|
||||
rtp_protocol._tone_bytes = tone_bytes
|
||||
rtp_protocol._processing_bytes = tone_bytes
|
||||
rtp_protocol._error_bytes = tone_bytes
|
||||
rtp_protocol.transport = Mock()
|
||||
rtp_protocol.send_audio = Mock()
|
||||
satellite._tts_extra_timeout = 0.001
|
||||
for tone in Tones:
|
||||
satellite._tone_bytes[tone] = tone_bytes
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
satellite.transport = Mock()
|
||||
satellite.send_audio = Mock()
|
||||
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -342,17 +489,17 @@ async def test_tts_timeout(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -361,26 +508,34 @@ async def test_tts_timeout(
|
|||
|
||||
async def test_tts_wrong_extension(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
|
|||
|
||||
async def test_tts_wrong_wav_format(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio with a specific format."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
async def test_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will not stream when output is empty."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -605,37 +754,78 @@ async def test_empty_tts_output(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to finish
|
||||
async with asyncio.timeout(1):
|
||||
await rtp_protocol._tts_done.wait()
|
||||
await satellite._tts_done.wait()
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_pipeline_error(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pipeline error causes the error tone to be played."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
# Fake error
|
||||
event_callback = kwargs["event_callback"]
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.ERROR,
|
||||
data={"code": "error-code", "message": "error message"},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be error.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
satellite._tones = Tones.ERROR
|
||||
satellite.transport = Mock()
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for error tone to be played
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
|
|||
timer_handle.cancel()
|
||||
timer_handle2.cancel()
|
||||
timer_handle3.cancel()
|
||||
|
||||
|
||||
async def test_queue_to_iterable() -> None:
|
||||
"""Test queue_to_iterable."""
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
expected_items = list(range(10))
|
||||
|
||||
for i in expected_items:
|
||||
await queue.put(i)
|
||||
|
||||
# Will terminate the stream
|
||||
await queue.put(None)
|
||||
|
||||
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
|
||||
|
||||
assert expected_items == actual_items
|
||||
|
||||
# Check timeout
|
||||
assert queue.empty()
|
||||
|
||||
# Time out on first item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check timeout on second item
|
||||
assert queue.empty()
|
||||
await queue.put(12345)
|
||||
|
||||
# Time out on second item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
if item != 12345:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert queue.empty()
|
||||
|
|
Loading…
Add table
Reference in a new issue