Compare commits
8 commits
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
898bb56519 | ||
|
1a6affc426 | ||
|
93cc266b06 | ||
|
f0c49b3995 | ||
|
d375bfaefe | ||
|
7fe4a52d59 | ||
|
a51de1df3c | ||
|
644427ecc7 |
36 changed files with 2224 additions and 672 deletions
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/aseko_pool_live/ @milanmeu
|
||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||
/tests/components/assist_satellite/ @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
|
|
@ -16,6 +16,7 @@ from .const import (
|
|||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
OPTION_PREFERRED,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
SAMPLE_WIDTH,
|
||||
|
@ -57,6 +58,7 @@ __all__ = (
|
|||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
"OPTION_PREFERRED",
|
||||
"SAMPLES_PER_CHUNK",
|
||||
"SAMPLE_RATE",
|
||||
"SAMPLE_WIDTH",
|
||||
|
|
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
|||
MS_PER_CHUNK = 10
|
||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
|
|
@ -504,7 +504,7 @@ class AudioSettings:
|
|||
is_vad_enabled: bool = True
|
||||
"""True if VAD is used to determine the end of the voice command."""
|
||||
|
||||
silence_seconds: float = 0.5
|
||||
silence_seconds: float = 0.7
|
||||
"""Seconds of silence after voice command has ended."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -906,6 +906,8 @@ class PipelineRun:
|
|||
metadata,
|
||||
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||
)
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
raise # expected
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||
raise SpeechToTextError(
|
||||
|
|
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
|
|
65
homeassistant/components/assist_satellite/__init__.py
Normal file
65
homeassistant/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
"""Base class for assist satellite entities."""
|
||||
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteState",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
"AssistSatelliteEntityFeature",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
component.async_register_entity_service(
|
||||
"announce",
|
||||
vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Optional("text"): str,
|
||||
vol.Optional("media"): str,
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key("text", "media"),
|
||||
),
|
||||
"async_annonuce",
|
||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
3
homeassistant/components/assist_satellite/const.py
Normal file
3
homeassistant/components/assist_satellite/const.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Constants for assist satellite."""
|
||||
|
||||
DOMAIN = "assist_satellite"
|
283
homeassistant/components/assist_satellite/entity.py
Normal file
283
homeassistant/components/assist_satellite/entity.py
Normal file
|
@ -0,0 +1,283 @@
|
|||
"""Assist satellite entity."""
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final
|
||||
|
||||
from homeassistant.components import media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .errors import SatelliteBusyError
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||
"""A class that describes assist satellite entities."""
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description: AssistSatelliteEntityDescription
|
||||
_attr_should_poll = False
|
||||
_attr_state: AssistSatelliteState | None = None
|
||||
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_is_announcing: bool = False
|
||||
_tts_finished_event: asyncio.Event | None = None
|
||||
_wake_word_future: asyncio.Future[str | None] | None = None
|
||||
|
||||
@property
|
||||
def is_announcing(self) -> bool:
|
||||
"""Returns true if currently announcing."""
|
||||
return self._is_announcing
|
||||
|
||||
async def async_announce(
|
||||
self,
|
||||
text: str | None = None,
|
||||
media_id: str | None = None,
|
||||
) -> None:
|
||||
"""Play an announcement on the satellite.
|
||||
|
||||
If media_id is not provided, text is synthesized to
|
||||
audio with the selected pipeline.
|
||||
|
||||
Calls _internal_async_announce with media id and expects it to block
|
||||
until the announcement is completed.
|
||||
"""
|
||||
if text is None:
|
||||
text = ""
|
||||
|
||||
if not media_id:
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
text,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
None,
|
||||
)
|
||||
media_id = media.url
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
self._is_announcing = True
|
||||
|
||||
try:
|
||||
# Block until announcement is finished
|
||||
await self._internal_async_announce(media_id)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
|
||||
async def _internal_async_announce(self, media_id: str) -> None:
|
||||
"""Announce the media URL on the satellite and returns when finished."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_intercepting_wake_word(self) -> bool:
|
||||
"""Return true if next wake word will be intercepted."""
|
||||
return (self._wake_word_future is not None) and (
|
||||
not self._wake_word_future.cancelled()
|
||||
)
|
||||
|
||||
async def async_intercept_wake_word(self) -> str | None:
|
||||
"""Intercept the next wake word from the satellite.
|
||||
|
||||
Returns the detected wake word phrase or None.
|
||||
"""
|
||||
if self._wake_word_future is not None:
|
||||
raise SatelliteBusyError
|
||||
|
||||
# Will cause next wake word to be intercepted in
|
||||
# _async_accept_pipeline_from_satellite
|
||||
self._wake_word_future = asyncio.Future()
|
||||
|
||||
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||
|
||||
try:
|
||||
return await self._wake_word_future
|
||||
finally:
|
||||
self._wake_word_future = None
|
||||
|
||||
return None
|
||||
|
||||
async def _async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
pipeline_entity_id: str | None = None,
|
||||
vad_sensitivity_entity_id: str | None = None,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Trigger an Assist pipeline in Home Assistant from a satellite."""
|
||||
if self.is_intercepting_wake_word:
|
||||
# Intercepting wake word and immediately end pipeline
|
||||
_LOGGER.debug(
|
||||
"Intercepted wake word: %s (entity_id=%s)",
|
||||
wake_word_phrase,
|
||||
self.entity_id,
|
||||
)
|
||||
assert self._wake_word_future is not None
|
||||
self._wake_word_future.set_result(wake_word_phrase)
|
||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||
return
|
||||
|
||||
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||
|
||||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
if vad_sensitivity_entity_id:
|
||||
if (
|
||||
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||
) is None:
|
||||
raise ValueError("VAD sensitivity entity not found")
|
||||
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
(self._context is None)
|
||||
or (self._context_set is None)
|
||||
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||
):
|
||||
self.async_set_context(Context())
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = ulid.ulid()
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._tts_finished_event = None
|
||||
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=pipeline_id,
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
||||
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type is PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._tts_finished_event = asyncio.Event()
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type is PipelineEventType.RUN_END:
|
||||
if self._tts_finished_event is None:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
self.on_pipeline_event(event)
|
||||
|
||||
def _set_state(self, state: AssistSatelliteState):
|
||||
"""Set the entity's state."""
|
||||
self._attr_state = state
|
||||
self.async_write_ha_state()
|
||||
|
||||
def tts_response_finished(self) -> None:
|
||||
"""Tell entity that the text-to-speech response has finished playing."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
if self._tts_finished_event is not None:
|
||||
self._tts_finished_event.set()
|
||||
|
||||
def _resolve_pipeline(self, pipeline_entity_id: str | None) -> str | None:
|
||||
"""Resolve pipeline from select entity to id."""
|
||||
if not pipeline_entity_id:
|
||||
return None
|
||||
|
||||
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
|
||||
raise ValueError("Pipeline entity not found")
|
||||
|
||||
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||
# Resolve pipeline by name
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
return pipeline.id
|
||||
|
||||
return None
|
11
homeassistant/components/assist_satellite/errors.py
Normal file
11
homeassistant/components/assist_satellite/errors.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""Errors for assist satellite."""
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
|
||||
class AssistSatelliteError(HomeAssistantError):
|
||||
"""Base class for assist satellite errors."""
|
||||
|
||||
|
||||
class SatelliteBusyError(AssistSatelliteError):
|
||||
"""Satellite is busy and cannot handle the request."""
|
7
homeassistant/components/assist_satellite/icons.json
Normal file
7
homeassistant/components/assist_satellite/icons.json
Normal file
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"default": "mdi:microphone-message"
|
||||
}
|
||||
}
|
||||
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "assist_satellite",
|
||||
"name": "Assist Satellite",
|
||||
"codeowners": ["@synesthesiam"],
|
||||
"config_flow": false,
|
||||
"dependencies": ["assist_pipeline", "stt", "tts"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity"
|
||||
}
|
26
homeassistant/components/assist_satellite/models.py
Normal file
26
homeassistant/components/assist_satellite/models.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""Models for assist satellite."""
|
||||
|
||||
from enum import IntFlag, StrEnum
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
||||
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||
|
||||
LISTENING_COMMAND = "listening_command"
|
||||
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||
|
||||
PROCESSING = "processing"
|
||||
"""Home Assistant is processing the voice command."""
|
||||
|
||||
RESPONDING = "responding"
|
||||
"""Device is speaking the response."""
|
||||
|
||||
|
||||
class AssistSatelliteEntityFeature(IntFlag):
|
||||
"""Supported features of Assist satellite entity."""
|
||||
|
||||
ANNOUNCE = 1
|
||||
"""Device supports remotely triggered announcements."""
|
13
homeassistant/components/assist_satellite/strings.json
Normal file
13
homeassistant/components/assist_satellite/strings.json
Normal file
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"name": "Assist satellite",
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
42
homeassistant/components/assist_satellite/websocket_api.py
Normal file
42
homeassistant/components/assist_satellite/websocket_api.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
"""Assist satellite Websocket API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||
vol.Required("entity_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Intercept the next wake word from a satellite."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
satellite = component.get_entity(msg["entity_id"])
|
||||
if satellite is None:
|
||||
connection.send_error(msg["id"], "entity_not_found", "Entity not found")
|
||||
return
|
||||
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
509
homeassistant/components/esphome/assist_satellite.py
Normal file
509
homeassistant/components/esphome/assist_satellite.py
Normal file
|
@ -0,0 +1,509 @@
|
|||
"""Support for assist satellites in ESPHome."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, cast
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_satellite, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EntityCategory, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import EsphomeAssistEntity
|
||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||
VoiceAssistantEventType, PipelineEventType
|
||||
] = EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
||||
}
|
||||
)
|
||||
|
||||
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
||||
EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
entry: ESPHomeConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Assist satellite entity."""
|
||||
entry_data = entry.runtime_data
|
||||
assert entry_data.device_info is not None
|
||||
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
entry_data.api_version
|
||||
):
|
||||
async_add_entities(
|
||||
[
|
||||
EsphomeAssistSatellite(hass, entry, entry_data),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class EsphomeAssistSatellite(
|
||||
EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
|
||||
):
|
||||
"""Satellite running ESPHome."""
|
||||
|
||||
entity_description = assist_satellite.AssistSatelliteEntityDescription(
|
||||
key="assist_satellite",
|
||||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
entry_data: RuntimeEntryData,
|
||||
) -> None:
|
||||
"""Initialize satellite."""
|
||||
super().__init__(entry_data)
|
||||
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.entry_data = entry_data
|
||||
self.cli = self.entry_data.client
|
||||
|
||||
self._is_running: bool = True
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._tts_streaming_task: asyncio.Task | None = None
|
||||
self._udp_server: VoiceAssistantUDPServer | None = None
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if feature_flags & VoiceAssistantFeature.API_AUDIO:
|
||||
# TCP audio
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
self.cli.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
handle_audio=self.handle_audio,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# UDP audio
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
self.cli.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.TIMERS:
|
||||
# Device supports timers
|
||||
assert (self.registry_entry is not None) and (
|
||||
self.registry_entry.device_id is not None
|
||||
)
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
async_register_timer_handler(
|
||||
self.hass, self.registry_entry.device_id, self.handle_timer_event
|
||||
)
|
||||
)
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.ANNOUNCE:
|
||||
# Device supports announcements
|
||||
self._attr_supported_features |= (
|
||||
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
self._is_running = False
|
||||
self._stop_pipeline()
|
||||
|
||||
async def _internal_async_announce(self, media_id: str) -> None:
|
||||
self.cli.send_voice_assistant_announce(media_id)
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
try:
|
||||
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
||||
return
|
||||
|
||||
data_to_send: dict[str, Any] = {}
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["tts_input"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||
assert event.data is not None
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
path = tts_output["url"]
|
||||
url = async_process_play_media_url(self.hass, path)
|
||||
data_to_send = {"url": url}
|
||||
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||
media_id = tts_output["media_id"]
|
||||
self._tts_streaming_task = (
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._stream_tts_audio(media_id),
|
||||
"esphome_voice_assistant_tts",
|
||||
)
|
||||
)
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||
assert event.data is not None
|
||||
if not event.data["wake_word_output"]:
|
||||
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||
data_to_send = {
|
||||
"code": "no_wake_word",
|
||||
"message": "No wake word detected",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"code": event.data["code"],
|
||||
"message": event.data["message"],
|
||||
}
|
||||
|
||||
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||
|
||||
async def handle_pipeline_start(
|
||||
self,
|
||||
conversation_id: str,
|
||||
flags: int,
|
||||
audio_settings: VoiceAssistantAudioSettings,
|
||||
wake_word_phrase: str | None,
|
||||
) -> int | None:
|
||||
"""Handle pipeline run request."""
|
||||
# Clear audio queue
|
||||
while not self._audio_queue.empty():
|
||||
await self._audio_queue.get()
|
||||
|
||||
if self._tts_streaming_task is not None:
|
||||
# Cancel current TTS response
|
||||
self._tts_streaming_task.cancel()
|
||||
self._tts_streaming_task = None
|
||||
|
||||
# API or UDP output audio
|
||||
port: int = 0
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
|
||||
feature_flags & VoiceAssistantFeature.API_AUDIO
|
||||
):
|
||||
port = await self._start_udp_server()
|
||||
_LOGGER.debug("Started UDP server on port %s", port)
|
||||
|
||||
# Get entity ids for pipeline and finished speaking detection
|
||||
ent_reg = er.async_get(self.hass)
|
||||
pipeline_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-pipeline",
|
||||
)
|
||||
vad_sensitivity_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||
)
|
||||
|
||||
# Device triggered pipeline (wake word, etc.)
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
start_stage = PipelineStage.STT
|
||||
|
||||
end_stage = PipelineStage.TTS
|
||||
|
||||
# Run the pipeline
|
||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._async_accept_pipeline_from_satellite(
|
||||
audio_stream=self._wrap_audio_stream(),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
pipeline_entity_id=pipeline_entity_id,
|
||||
vad_sensitivity_entity_id=vad_sensitivity_entity_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
),
|
||||
"esphome_assist_satellite_pipeline",
|
||||
)
|
||||
self._pipeline_task.add_done_callback(
|
||||
lambda _future: self.handle_pipeline_finished()
|
||||
)
|
||||
|
||||
return port
|
||||
|
||||
async def handle_audio(self, data: bytes) -> None:
|
||||
"""Handle incoming audio chunk from API."""
|
||||
self._audio_queue.put_nowait(data)
|
||||
|
||||
async def handle_pipeline_stop(self) -> None:
|
||||
"""Handle request for pipeline to stop."""
|
||||
self._stop_pipeline()
|
||||
|
||||
def handle_pipeline_finished(self) -> None:
|
||||
"""Handle when pipeline has finished running."""
|
||||
self.entry_data.async_set_assist_pipeline_state(False)
|
||||
self._stop_udp_server()
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def handle_timer_event(
|
||||
self, event_type: TimerEventType, timer_info: TimerInfo
|
||||
) -> None:
|
||||
"""Handle timer events."""
|
||||
try:
|
||||
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
||||
return
|
||||
|
||||
self.cli.send_voice_assistant_timer_event(
|
||||
native_event_type,
|
||||
timer_info.id,
|
||||
timer_info.name,
|
||||
timer_info.created_seconds,
|
||||
timer_info.seconds_left,
|
||||
timer_info.is_active,
|
||||
)
|
||||
|
||||
async def _stream_tts_audio(
|
||||
self,
|
||||
media_id: str,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
sample_channels: int = 1,
|
||||
samples_per_chunk: int = 512,
|
||||
) -> None:
|
||||
"""Stream TTS audio chunks to device via API or UDP."""
|
||||
self.cli.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||
)
|
||||
|
||||
try:
|
||||
if not self._is_running:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
or (wav_file.getsampwidth() != sample_width)
|
||||
or (wav_file.getnchannels() != sample_channels)
|
||||
):
|
||||
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||
|
||||
while True:
|
||||
chunk = wav_file.readframes(samples_per_chunk)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if self._udp_server is not None:
|
||||
self._udp_server.send_audio_bytes(chunk)
|
||||
else:
|
||||
self.cli.send_voice_assistant_audio(chunk)
|
||||
|
||||
# Wait for 90% of the duration of the audio that was
|
||||
# sent for it to be played. This will overrun the
|
||||
# device's buffer for very long audio, so using a media
|
||||
# player is preferred.
|
||||
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||
except asyncio.CancelledError:
|
||||
return # Don't trigger state change
|
||||
finally:
|
||||
self.cli.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||
)
|
||||
|
||||
# State change
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks from the queue until None."""
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
def _stop_pipeline(self) -> None:
|
||||
"""Request pipeline to be stopped."""
|
||||
self._audio_queue.put_nowait(None)
|
||||
_LOGGER.debug("Requested pipeline stop")
|
||||
|
||||
async def _start_udp_server(self) -> int:
|
||||
"""Start a UDP server on a random free port."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(("", 0)) # random free port
|
||||
|
||||
(
|
||||
_transport,
|
||||
protocol,
|
||||
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
|
||||
)
|
||||
|
||||
assert isinstance(protocol, VoiceAssistantUDPServer)
|
||||
self._udp_server = protocol
|
||||
|
||||
# Return port
|
||||
return cast(int, sock.getsockname()[1])
|
||||
|
||||
def _stop_udp_server(self) -> None:
|
||||
"""Stop the UDP server if it's running."""
|
||||
if self._udp_server is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._udp_server.close()
|
||||
finally:
|
||||
self._udp_server = None
|
||||
|
||||
_LOGGER.debug("Stopped UDP server")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
"""Receive UDP packets and forward them to the audio queue."""
|
||||
|
||||
transport: asyncio.DatagramTransport | None = None
|
||||
remote_addr: tuple[str, int] | None = None
|
||||
|
||||
def __init__(
|
||||
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize protocol."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._audio_queue = audio_queue
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Store transport for later use."""
|
||||
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Handle incoming UDP packet."""
|
||||
if self.remote_addr is None:
|
||||
self.remote_addr = addr
|
||||
|
||||
self._audio_queue.put_nowait(data)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
"""Handle when a send or receive operation raises an OSError.
|
||||
|
||||
(Other than BlockingIOError or InterruptedError.)
|
||||
"""
|
||||
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||
|
||||
# Stop pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the receiver."""
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
self.remote_addr = None
|
||||
|
||||
def send_audio_bytes(self, data: bytes) -> None:
|
||||
"""Send bytes to the device via UDP."""
|
||||
if self.transport is None:
|
||||
_LOGGER.error("No transport to send audio to")
|
||||
return
|
||||
|
||||
if self.remote_addr is None:
|
||||
_LOGGER.error("No address to send audio to")
|
||||
return
|
||||
|
||||
self.transport.sendto(data, self.remote_addr)
|
|
@ -27,12 +27,12 @@ from awesomeversion import AwesomeVersion
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import tag, zeroconf
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_ID,
|
||||
CONF_MODE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_LOGGING_CHANGED,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
|
@ -77,7 +77,6 @@ from .voice_assistant import (
|
|||
VoiceAssistantAPIPipeline,
|
||||
VoiceAssistantPipeline,
|
||||
VoiceAssistantUDPPipeline,
|
||||
handle_timer_event,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -500,29 +499,14 @@ class ESPHomeManager:
|
|||
)
|
||||
)
|
||||
|
||||
flags = device_info.voice_assistant_feature_flags_compat(api_version)
|
||||
if flags:
|
||||
if flags & VoiceAssistantFeature.API_AUDIO:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
cli.subscribe_voice_assistant(
|
||||
handle_start=self._handle_pipeline_start,
|
||||
handle_stop=self._handle_pipeline_stop,
|
||||
handle_audio=self._handle_audio,
|
||||
)
|
||||
)
|
||||
else:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
cli.subscribe_voice_assistant(
|
||||
handle_start=self._handle_pipeline_start,
|
||||
handle_stop=self._handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
if flags & VoiceAssistantFeature.TIMERS:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
async_register_timer_handler(
|
||||
hass, self.device_id, partial(handle_timer_event, cli)
|
||||
)
|
||||
)
|
||||
if device_info.voice_assistant_feature_flags_compat(api_version) and (
|
||||
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
|
||||
):
|
||||
# Create assist satellite entity
|
||||
await self.hass.config_entries.async_forward_entry_setups(
|
||||
self.entry, [Platform.ASSIST_SATELLITE]
|
||||
)
|
||||
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
|
||||
|
||||
cli.subscribe_states(entry_data.async_update_state)
|
||||
cli.subscribe_service_calls(self.async_on_service_call)
|
||||
|
@ -844,4 +828,5 @@ async def cleanup_instance(
|
|||
cleanup_callback()
|
||||
await data.async_cleanup()
|
||||
await data.client.disconnect()
|
||||
|
||||
return data
|
||||
|
|
|
@ -59,6 +59,17 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"assist_in_progress": {
|
||||
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
||||
|
|
|
@ -20,6 +20,7 @@ from .devices import VoIPDevices
|
|||
from .voip import HassVoipDatagramProtocol
|
||||
|
||||
PLATFORMS = (
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
|
306
homeassistant/components/voip/assist_satellite.py
Normal file
306
homeassistant/components/voip/assist_satellite.py
Normal file
|
@ -0,0 +1,306 @@
|
|||
"""Assist satellite entity for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import IntFlag
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
import wave
|
||||
|
||||
from voip_utils import RtpDatagramProtocol
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.async_ import queue_to_iterable
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||
|
||||
|
||||
class Tones(IntFlag):
|
||||
"""Feedback tones for specific events."""
|
||||
|
||||
LISTENING = 1
|
||||
PROCESSING = 2
|
||||
ERROR = 4
|
||||
|
||||
|
||||
_TONE_FILENAMES: dict[Tones, str] = {
|
||||
Tones.LISTENING: "tone.pcm",
|
||||
Tones.PROCESSING: "processing.pcm",
|
||||
Tones.ERROR: "error.pcm",
|
||||
}
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP Assist satellite entity."""
|
||||
domain_data: DomainData = hass.data[DOMAIN]
|
||||
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
entities: list[VoIPEntity] = [
|
||||
VoipAssistSatellite(hass, device, config_entry)
|
||||
for device in domain_data.devices
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||
"""Assist satellite for VoIP devices."""
|
||||
|
||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
voip_device: VoIPDevice,
|
||||
config_entry: ConfigEntry,
|
||||
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
|
||||
) -> None:
|
||||
"""Initialize an Assist satellite."""
|
||||
VoIPEntity.__init__(self, voip_device)
|
||||
AssistSatelliteEntity.__init__(self)
|
||||
RtpDatagramProtocol.__init__(self)
|
||||
|
||||
self.config_entry = config_entry
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._audio_chunk_timeout: float = 2.0
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._pipeline_had_error: bool = False
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_extra_timeout: float = 1.0
|
||||
self._tone_bytes: dict[Tones, bytes] = {}
|
||||
self._tones = tones
|
||||
self._processing_tone_done = asyncio.Event()
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
self.voip_device.protocol = self
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
assert self.voip_device.protocol == self
|
||||
self.voip_device.protocol = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||
|
||||
try:
|
||||
self._tts_done.clear()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
|
||||
audio_stream=queue_to_iterable(
|
||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||
),
|
||||
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
|
||||
self.hass
|
||||
),
|
||||
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
|
||||
self.hass
|
||||
),
|
||||
)
|
||||
|
||||
if self._pipeline_had_error:
|
||||
self._pipeline_had_error = False
|
||||
await self._play_tone(Tones.ERROR)
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||
self.disconnect()
|
||||
self._clear_audio_queue()
|
||||
finally:
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type == PipelineEventType.STT_END:
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
self._processing_tone_done.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
|
||||
)
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||
self._pipeline_had_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return # not connected
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
# Don't overlap TTS and processing beep
|
||||
await self._processing_tone_done.wait()
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
# Update satellite state
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
|
||||
"""Play a tone as feedback to the user if it's enabled."""
|
||||
if (self._tones & tone) != tone:
|
||||
return # not enabled
|
||||
|
||||
if tone not in self._tone_bytes:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
_TONE_FILENAMES[tone],
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes[tone],
|
||||
silence_before=silence_before,
|
||||
)
|
||||
|
||||
if tone == Tones.PROCESSING:
|
||||
self._processing_tone_done.set()
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
|
@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
|
|||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
|
||||
self.async_on_remove(
|
||||
self.voip_device.async_listen_update(self._is_active_changed)
|
||||
)
|
||||
|
||||
@callback
|
||||
def _is_active_changed(self, device: VoIPDevice) -> None:
|
||||
"""Call when active state changed."""
|
||||
self._attr_is_on = self._device.is_active
|
||||
self._attr_is_on = self.voip_device.is_active
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voip_utils import CallInfo
|
||||
from voip_utils import CallInfo, VoipDatagramProtocol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
@ -22,6 +22,7 @@ class VoIPDevice:
|
|||
device_id: str
|
||||
is_active: bool = False
|
||||
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
||||
protocol: VoipDatagramProtocol | None = None
|
||||
|
||||
@callback
|
||||
def set_is_active(self, active: bool) -> None:
|
||||
|
@ -56,6 +57,18 @@ class VoIPDevice:
|
|||
|
||||
return False
|
||||
|
||||
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for pipeline select."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
|
||||
|
||||
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for VAD sensitivity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
|
||||
)
|
||||
|
||||
|
||||
class VoIPDevices:
|
||||
"""Class to store devices."""
|
||||
|
|
|
@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
|
|||
_attr_has_entity_name = True
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, device: VoIPDevice) -> None:
|
||||
def __init__(self, voip_device: VoIPDevice) -> None:
|
||||
"""Initialize VoIP entity."""
|
||||
self._device = device
|
||||
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
|
||||
self.voip_device = voip_device
|
||||
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, device.voip_id)},
|
||||
identifiers={(DOMAIN, voip_device.voip_id)},
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline"],
|
||||
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
|
|
@ -10,6 +10,17 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"call_in_progress": {
|
||||
"name": "Call in progress"
|
||||
|
|
|
@ -3,15 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import wave
|
||||
|
||||
from voip_utils import (
|
||||
CallInfo,
|
||||
|
@ -21,33 +17,19 @@ from voip_utils import (
|
|||
VoipDatagramProtocol,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_get_pipeline,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.audio_enhancer import (
|
||||
AudioEnhancer,
|
||||
MicroVadEnhancer,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import (
|
||||
AudioBuffer,
|
||||
VadSensitivity,
|
||||
VoiceCommandSegmenter,
|
||||
)
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
from .devices import VoIPDevices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,11 +42,8 @@ def make_protocol(
|
|||
) -> VoipDatagramProtocol:
|
||||
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
||||
voip_device = devices.async_get_or_create(call_info)
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
|
||||
try:
|
||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||
except PipelineNotFound:
|
||||
|
@ -83,22 +62,18 @@ def make_protocol(
|
|||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
if (protocol := voip_device.protocol) is None:
|
||||
raise ValueError("VoIP satellite not found")
|
||||
|
||||
# Pipeline is properly configured
|
||||
return PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(user_id=devices.config_entry.data["user"]),
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
|
||||
protocol.rtcp_state = rtcp_state
|
||||
if protocol.rtcp_state is not None:
|
||||
# Automatically disconnect when BYE is received over RTCP
|
||||
protocol.rtcp_state.bye_callback = protocol.disconnect
|
||||
|
||||
return protocol
|
||||
|
||||
|
||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||
|
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
|||
await self._closed_event.wait()
|
||||
|
||||
|
||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
language: str,
|
||||
voip_device: VoIPDevice,
|
||||
context: Context,
|
||||
opus_payload_type: int,
|
||||
pipeline_timeout: float = 30.0,
|
||||
audio_timeout: float = 2.0,
|
||||
buffered_chunks_before_speech: int = 100,
|
||||
listening_tone_enabled: bool = True,
|
||||
processing_tone_enabled: bool = True,
|
||||
error_tone_enabled: bool = True,
|
||||
tone_delay: float = 0.2,
|
||||
tts_extra_timeout: float = 1.0,
|
||||
silence_seconds: float = 1.0,
|
||||
rtcp_state: RtcpState | None = None,
|
||||
) -> None:
|
||||
"""Set up pipeline RTP server."""
|
||||
super().__init__(
|
||||
rate=RATE,
|
||||
width=WIDTH,
|
||||
channels=CHANNELS,
|
||||
opus_payload_type=opus_payload_type,
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
self.hass = hass
|
||||
self.language = language
|
||||
self.voip_device = voip_device
|
||||
self.pipeline: Pipeline | None = None
|
||||
self.pipeline_timeout = pipeline_timeout
|
||||
self.audio_timeout = audio_timeout
|
||||
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
||||
self.listening_tone_enabled = listening_tone_enabled
|
||||
self.processing_tone_enabled = processing_tone_enabled
|
||||
self.error_tone_enabled = error_tone_enabled
|
||||
self.tone_delay = tone_delay
|
||||
self.tts_extra_timeout = tts_extra_timeout
|
||||
self.silence_seconds = silence_seconds
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._context = context
|
||||
self._conversation_id: str | None = None
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._tts_done = asyncio.Event()
|
||||
self._session_id: str | None = None
|
||||
self._tone_bytes: bytes | None = None
|
||||
self._processing_bytes: bytes | None = None
|
||||
self._error_bytes: bytes | None = None
|
||||
self._pipeline_error: bool = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Server is ready."""
|
||||
super().connection_made(transport)
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Handle connection is lost or closed."""
|
||||
super().connection_lost(exc)
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.hass.async_create_background_task(
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
if self._session_id is None:
|
||||
self._session_id = ulid_now()
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
if self.listening_tone_enabled:
|
||||
await self._play_listening_tone()
|
||||
|
||||
try:
|
||||
# Wait for speech before starting pipeline
|
||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||
audio_enhancer = MicroVadEnhancer(0, 0, True)
|
||||
chunk_buffer: deque[bytes] = deque(
|
||||
maxlen=self.buffered_chunks_before_speech,
|
||||
)
|
||||
speech_detected = await self._wait_for_speech(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
)
|
||||
if not speech_detected:
|
||||
_LOGGER.debug("No speech detected")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
self._tts_done.clear()
|
||||
|
||||
async def stt_stream():
|
||||
try:
|
||||
async for chunk in self._segment_audio(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if self.processing_tone_enabled:
|
||||
await self._play_processing_tone()
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Audio timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
async with asyncio.timeout(self.pipeline_timeout):
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.voip_device.voip_id
|
||||
),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=self.voip_device.device_id,
|
||||
tts_audio_output="wav",
|
||||
)
|
||||
|
||||
if self._pipeline_error:
|
||||
self._pipeline_error = False
|
||||
if self.error_tone_enabled:
|
||||
await self._play_error_tone()
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
async def _wait_for_speech(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: MutableSequence[bytes],
|
||||
):
|
||||
"""Buffer audio chunks until speech is detected.
|
||||
|
||||
Returns True if speech was detected, False otherwise.
|
||||
"""
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
chunk_buffer.append(chunk)
|
||||
|
||||
segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
)
|
||||
if segmenter.in_command:
|
||||
# Buffer until command starts
|
||||
if len(vad_buffer) > 0:
|
||||
chunk_buffer.append(vad_buffer.bytes())
|
||||
|
||||
return True
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
return False
|
||||
|
||||
async def _segment_audio(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: Sequence[bytes],
|
||||
) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks until voice command has finished."""
|
||||
# Buffered chunks first
|
||||
for buffered_chunk in chunk_buffer:
|
||||
yield buffered_chunk
|
||||
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
if not segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def _event_callback(self, event: PipelineEvent):
|
||||
if not event.data:
|
||||
return
|
||||
|
||||
if event.type == PipelineEventType.INTENT_END:
|
||||
# Capture conversation id
|
||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS
|
||||
self._pipeline_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_listening_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is listening."""
|
||||
if self._tone_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"tone.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes,
|
||||
silence_before=self.tone_delay,
|
||||
)
|
||||
|
||||
async def _play_processing_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
||||
if self._processing_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._processing_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"processing.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._processing_bytes)
|
||||
|
||||
async def _play_error_tone(self) -> None:
|
||||
"""Play a tone to indicate a pipeline error occurred."""
|
||||
if self._error_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._error_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"error.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._error_bytes)
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
||||
|
||||
|
||||
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
||||
"""Plays a pre-recorded message on a loop."""
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
|||
|
||||
AIR_QUALITY = "air_quality"
|
||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||
ASSIST_SATELLITE = "assist_satellite"
|
||||
BINARY_SENSOR = "binary_sensor"
|
||||
BUTTON = "button"
|
||||
CALENDAR = "calendar"
|
||||
|
|
|
@ -5,22 +5,28 @@ from __future__ import annotations
|
|||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
Future,
|
||||
Queue,
|
||||
Semaphore,
|
||||
Task,
|
||||
TimerHandle,
|
||||
gather,
|
||||
get_running_loop,
|
||||
timeout as async_timeout,
|
||||
)
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
||||
|
||||
_DataT = TypeVar("_DataT", default=Any)
|
||||
|
||||
|
||||
def create_eager_task[_T](
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
|
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
|
|||
"""Return a list of scheduled TimerHandles."""
|
||||
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
||||
return handles
|
||||
|
||||
|
||||
async def queue_to_iterable(
|
||||
queue: Queue[_DataT], timeout: float | None = None
|
||||
) -> AsyncIterable[_DataT]:
|
||||
"""Stream items from a queue until None with an optional timeout per item."""
|
||||
if timeout is None:
|
||||
while (item := await queue.get()) is not None:
|
||||
yield item
|
||||
else:
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
||||
while item is not None:
|
||||
yield item
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
|
1
tests/components/assist_satellite/__init__.py
Normal file
1
tests/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Tests for Assist Satellite."""
|
106
tests/components/assist_satellite/conftest.py
Normal file
106
tests/components/assist_satellite/conftest.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
"""Test helpers for Assist Satellite."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test_satellite"
|
||||
|
||||
|
||||
class MockAssistSatellite(AssistSatelliteEntity):
|
||||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
self.events = []
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity() -> MockAssistSatellite:
|
||||
"""Mock Assist Satellite Entity."""
|
||||
return MockAssistSatellite()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Mock config entry."""
|
||||
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Initialize components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test tts platform via config entry."""
|
||||
async_add_entities([entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
88
tests/components/assist_satellite/test_entity.py
Normal file
88
tests/components/assist_satellite/test_entity.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
"""Test the Assist Satellite entity."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteState
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
||||
|
||||
|
||||
async def test_entity_state(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test entity state represent events."""
|
||||
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
context = Context()
|
||||
|
||||
audio_stream = object()
|
||||
|
||||
entity.async_set_context(context)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||
) as mock_start_pipeline:
|
||||
await entity._async_accept_pipeline_from_satellite(audio_stream) # type: ignore[arg-type]
|
||||
|
||||
assert mock_start_pipeline.called
|
||||
kwargs = mock_start_pipeline.call_args[1]
|
||||
assert kwargs["context"] is context
|
||||
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] == "wav"
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
)
|
||||
assert kwargs["start_stage"] == PipelineStage.STT
|
||||
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||
|
||||
for event_type, expected_state in (
|
||||
(PipelineEventType.RUN_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.RUN_END, AssistSatelliteState.RESPONDING),
|
||||
):
|
||||
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == expected_state, event_type
|
||||
|
||||
entity.tts_response_finished()
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
"""Test the Assist Satellite websocket API."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteEntityFeature
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
||||
|
||||
|
||||
async def audio_stream() -> AsyncIterable[bytes]:
|
||||
"""Empty audio stream."""
|
||||
yield b""
|
||||
|
||||
|
||||
async def test_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/intercept_wake_word command."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineInput.validate",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_recognize_intent",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_text_to_speech",
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(entity, "on_pipeline_event") as mock_on_pipeline_event,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{"type": "assist_satellite/intercept_wake_word", "entity_id": ENTITY_ID}
|
||||
)
|
||||
|
||||
# Wait for interception to start
|
||||
while not entity.is_intercepting_wake_word:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Start a pipeline with a wake word
|
||||
await entity._async_accept_pipeline_from_satellite(
|
||||
audio_stream=audio_stream(),
|
||||
start_stage=PipelineStage.STT,
|
||||
end_stage=PipelineStage.TTS,
|
||||
wake_word_phrase="test wake word",
|
||||
)
|
||||
|
||||
# Verify that wake word was intercepted
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == {"wake_word_phrase": "test wake word"}
|
||||
|
||||
# Verify that only run end event was sent to pipeline
|
||||
mock_on_pipeline_event.assert_called_once_with(
|
||||
PipelineEvent(PipelineEventType.RUN_END, data=None, timestamp=ANY)
|
||||
)
|
||||
|
||||
|
||||
async def test_announce_not_supported(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with an entity that doesn't support announcements."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch.object(
|
||||
entity, "_attr_supported_features", AssistSatelliteEntityFeature(0)
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"media_id": "test media id",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == ERR_NOT_SUPPORTED
|
||||
|
||||
|
||||
async def test_announce_media_id(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with media id."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
entity, "_internal_async_announce"
|
||||
) as mock_internal_async_announce,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"media_id": "test media id",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
# Verify media id was passed through
|
||||
mock_internal_async_announce.assert_called_once_with("test media id")
|
||||
|
||||
|
||||
async def test_announce_text(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with text."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(url="test media id", mime_type=""),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_process_play_media_url",
|
||||
return_value="test media id",
|
||||
),
|
||||
patch.object(
|
||||
entity, "_internal_async_announce"
|
||||
) as mock_internal_async_announce,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"text": "test text",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
# Verify media id was passed through
|
||||
mock_internal_async_announce.assert_called_once_with("test media id")
|
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
File diff suppressed because one or more lines are too long
|
@ -3,15 +3,26 @@
|
|||
import asyncio
|
||||
import io
|
||||
from pathlib import Path
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from voip_utils import CallInfo
|
||||
|
||||
from homeassistant.components import assist_pipeline, voip
|
||||
from homeassistant.components.voip.devices import VoIPDevice
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, voip
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
|
|||
return wav_io.getvalue()
|
||||
|
||||
|
||||
def async_get_satellite_entity(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> AssistSatelliteEntity | None:
|
||||
"""Get Assist satellite entity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
|
||||
)
|
||||
if satellite_entity_id is None:
|
||||
return None
|
||||
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||
assist_satellite.DOMAIN
|
||||
]
|
||||
return component.get_entity(satellite_entity_id)
|
||||
|
||||
|
||||
async def test_is_valid_call(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that a call is now allowed from an unknown device."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol = HassVoipDatagramProtocol(hass, voip_devices)
|
||||
assert not protocol.is_valid_call(call_info)
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
allowed_call_entity_id = ent_reg.async_get_entity_id(
|
||||
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
|
||||
)
|
||||
assert allowed_call_entity_id is not None
|
||||
state = hass.states.get(allowed_call_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
|
||||
# Allow calls
|
||||
hass.states.async_set(allowed_call_entity_id, STATE_ON)
|
||||
assert protocol.is_valid_call(call_info)
|
||||
|
||||
|
||||
async def test_calls_not_allowed(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when calls aren't allowed."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
# Test the playback
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
def send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be problem.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
protocol.transport = Mock()
|
||||
protocol.loop_delay = 0
|
||||
with patch.object(protocol, "send_audio", send_audio):
|
||||
protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
||||
|
||||
async def test_pipeline_not_found(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when a pipeline isn't found."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
|
||||
):
|
||||
protocol: PreRecordMessageProtocol = make_protocol(
|
||||
hass, voip_devices, call_info
|
||||
)
|
||||
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
|
||||
async def test_satellite_prepared(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that satellite is prepared for a call."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test",
|
||||
stt_engine="test",
|
||||
stt_language="en",
|
||||
tts_engine="test",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
protocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert protocol == satellite
|
||||
|
||||
|
||||
async def test_pipeline(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that pipeline function is called from RTP protocol."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
voip_user_id = satellite.config_entry.data["user"]
|
||||
assert voip_user_id
|
||||
|
||||
return 0
|
||||
# Satellite is muted until a call begins
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
# Used to test that audio queue is cleared before pipeline starts
|
||||
bad_chunk = bytes([1, 2, 3, 4])
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
|
||||
):
|
||||
assert context.user_id == voip_user_id
|
||||
assert device_id == voip_device.device_id
|
||||
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
assert _chunk != bad_chunk
|
||||
assert chunk != bad_chunk
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Test empty data
|
||||
event_callback(
|
||||
|
@ -71,6 +229,38 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_START,
|
||||
data={"engine": "test", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "fake-text",
|
||||
"conversation_id": None,
|
||||
"device_id": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -83,6 +273,21 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# Fake tts result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"voice": "test",
|
||||
"tts_input": "fake-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Proceed with media output
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -91,6 +296,18 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.RUN_END
|
||||
)
|
||||
)
|
||||
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
done.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
|
@ -100,102 +317,56 @@ async def test_pipeline(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite._tones = Tones(0)
|
||||
satellite.transport = Mock()
|
||||
|
||||
satellite.connection_made(satellite.transport)
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
# Ensure audio queue is cleared before pipeline starts
|
||||
rtp_protocol._audio_queue.put_nowait(bad_chunk)
|
||||
satellite._audio_queue.put_nowait(bad_chunk)
|
||||
|
||||
def send_audio(*args, **kwargs):
|
||||
# Test finished successfully
|
||||
done.set()
|
||||
# Don't send audio
|
||||
pass
|
||||
|
||||
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
||||
satellite.send_audio = Mock(side_effect=send_audio)
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes aggressive VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
"""Test timeout during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
pipeline_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
# Finished speaking
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
async def test_stt_stream_timeout(
|
||||
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
|
||||
) -> None:
|
||||
"""Test timeout in STT stream during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
|
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
pass
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
audio_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
satellite._tones = Tones(0)
|
||||
satellite._audio_chunk_timeout = 0.001
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
satellite.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
|
||||
async def test_tts_timeout(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will time out based on its length."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -278,15 +448,7 @@ async def test_tts_timeout(
|
|||
|
||||
tone_bytes = bytes([1, 2, 3, 4])
|
||||
|
||||
def send_audio(audio_bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
||||
# Block here to force a timeout in _send_tts
|
||||
time.sleep(2)
|
||||
|
||||
async def async_send_audio(audio_bytes, **kwargs):
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
@ -303,37 +465,22 @@ async def test_tts_timeout(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
tts_extra_timeout=0.001,
|
||||
listening_tone_enabled=True,
|
||||
processing_tone_enabled=True,
|
||||
error_tone_enabled=True,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
||||
)
|
||||
rtp_protocol._tone_bytes = tone_bytes
|
||||
rtp_protocol._processing_bytes = tone_bytes
|
||||
rtp_protocol._error_bytes = tone_bytes
|
||||
rtp_protocol.transport = Mock()
|
||||
rtp_protocol.send_audio = Mock()
|
||||
satellite._tts_extra_timeout = 0.001
|
||||
for tone in Tones:
|
||||
satellite._tone_bytes[tone] = tone_bytes
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
satellite.transport = Mock()
|
||||
satellite.send_audio = Mock()
|
||||
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -342,17 +489,17 @@ async def test_tts_timeout(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -361,26 +508,34 @@ async def test_tts_timeout(
|
|||
|
||||
async def test_tts_wrong_extension(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
|
|||
|
||||
async def test_tts_wrong_wav_format(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio with a specific format."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
async def test_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will not stream when output is empty."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -605,37 +754,78 @@ async def test_empty_tts_output(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to finish
|
||||
async with asyncio.timeout(1):
|
||||
await rtp_protocol._tts_done.wait()
|
||||
await satellite._tts_done.wait()
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_pipeline_error(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pipeline error causes the error tone to be played."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
# Fake error
|
||||
event_callback = kwargs["event_callback"]
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.ERROR,
|
||||
data={"code": "error-code", "message": "error message"},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be error.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
satellite._tones = Tones.ERROR
|
||||
satellite.transport = Mock()
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for error tone to be played
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
|
|||
timer_handle.cancel()
|
||||
timer_handle2.cancel()
|
||||
timer_handle3.cancel()
|
||||
|
||||
|
||||
async def test_queue_to_iterable() -> None:
|
||||
"""Test queue_to_iterable."""
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
expected_items = list(range(10))
|
||||
|
||||
for i in expected_items:
|
||||
await queue.put(i)
|
||||
|
||||
# Will terminate the stream
|
||||
await queue.put(None)
|
||||
|
||||
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
|
||||
|
||||
assert expected_items == actual_items
|
||||
|
||||
# Check timeout
|
||||
assert queue.empty()
|
||||
|
||||
# Time out on first item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check timeout on second item
|
||||
assert queue.empty()
|
||||
await queue.put(12345)
|
||||
|
||||
# Time out on second item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
if item != 12345:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert queue.empty()
|
||||
|
|
Loading…
Add table
Reference in a new issue