Compare commits
8 commits
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
898bb56519 | ||
|
1a6affc426 | ||
|
93cc266b06 | ||
|
f0c49b3995 | ||
|
d375bfaefe | ||
|
7fe4a52d59 | ||
|
a51de1df3c | ||
|
644427ecc7 |
36 changed files with 2224 additions and 672 deletions
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
||||||
/tests/components/aseko_pool_live/ @milanmeu
|
/tests/components/aseko_pool_live/ @milanmeu
|
||||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||||
|
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||||
|
/tests/components/assist_satellite/ @synesthesiam
|
||||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||||
/homeassistant/components/atag/ @MatsNL
|
/homeassistant/components/atag/ @MatsNL
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .const import (
|
||||||
DATA_LAST_WAKE_UP,
|
DATA_LAST_WAKE_UP,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
EVENT_RECORDING,
|
EVENT_RECORDING,
|
||||||
|
OPTION_PREFERRED,
|
||||||
SAMPLE_CHANNELS,
|
SAMPLE_CHANNELS,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
SAMPLE_WIDTH,
|
SAMPLE_WIDTH,
|
||||||
|
@ -57,6 +58,7 @@ __all__ = (
|
||||||
"PipelineNotFound",
|
"PipelineNotFound",
|
||||||
"WakeWordSettings",
|
"WakeWordSettings",
|
||||||
"EVENT_RECORDING",
|
"EVENT_RECORDING",
|
||||||
|
"OPTION_PREFERRED",
|
||||||
"SAMPLES_PER_CHUNK",
|
"SAMPLES_PER_CHUNK",
|
||||||
"SAMPLE_RATE",
|
"SAMPLE_RATE",
|
||||||
"SAMPLE_WIDTH",
|
"SAMPLE_WIDTH",
|
||||||
|
|
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
||||||
MS_PER_CHUNK = 10
|
MS_PER_CHUNK = 10
|
||||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||||
|
|
||||||
|
OPTION_PREFERRED = "preferred"
|
||||||
|
|
|
@ -504,7 +504,7 @@ class AudioSettings:
|
||||||
is_vad_enabled: bool = True
|
is_vad_enabled: bool = True
|
||||||
"""True if VAD is used to determine the end of the voice command."""
|
"""True if VAD is used to determine the end of the voice command."""
|
||||||
|
|
||||||
silence_seconds: float = 0.5
|
silence_seconds: float = 0.7
|
||||||
"""Seconds of silence after voice command has ended."""
|
"""Seconds of silence after voice command has ended."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
@ -906,6 +906,8 @@ class PipelineRun:
|
||||||
metadata,
|
metadata,
|
||||||
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||||
)
|
)
|
||||||
|
except (asyncio.CancelledError, TimeoutError):
|
||||||
|
raise # expected
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||||
raise SpeechToTextError(
|
raise SpeechToTextError(
|
||||||
|
|
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN, OPTION_PREFERRED
|
||||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||||
from .vad import VadSensitivity
|
from .vad import VadSensitivity
|
||||||
|
|
||||||
OPTION_PREFERRED = "preferred"
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def get_chosen_pipeline(
|
def get_chosen_pipeline(
|
||||||
|
|
65
homeassistant/components/assist_satellite/__init__.py
Normal file
65
homeassistant/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
"""Base class for assist satellite entities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||||
|
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||||
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DOMAIN",
|
||||||
|
"AssistSatelliteState",
|
||||||
|
"AssistSatelliteEntity",
|
||||||
|
"AssistSatelliteEntityDescription",
|
||||||
|
"AssistSatelliteEntityFeature",
|
||||||
|
]
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||||
|
_LOGGER, DOMAIN, hass
|
||||||
|
)
|
||||||
|
await component.async_setup(config)
|
||||||
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"announce",
|
||||||
|
vol.All(
|
||||||
|
vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional("text"): str,
|
||||||
|
vol.Optional("media"): str,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
cv.has_at_least_one_key("text", "media"),
|
||||||
|
),
|
||||||
|
"async_annonuce",
|
||||||
|
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Set up a config entry."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_setup_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Unload a config entry."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_unload_entry(entry)
|
3
homeassistant/components/assist_satellite/const.py
Normal file
3
homeassistant/components/assist_satellite/const.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
"""Constants for assist satellite."""
|
||||||
|
|
||||||
|
DOMAIN = "assist_satellite"
|
283
homeassistant/components/assist_satellite/entity.py
Normal file
283
homeassistant/components/assist_satellite/entity.py
Normal file
|
@ -0,0 +1,283 @@
|
||||||
|
"""Assist satellite entity."""
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Final
|
||||||
|
|
||||||
|
from homeassistant.components import media_source, stt, tts
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
OPTION_PREFERRED,
|
||||||
|
AudioSettings,
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
async_get_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
|
async_pipeline_from_audio_stream,
|
||||||
|
vad,
|
||||||
|
)
|
||||||
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
|
from homeassistant.components.tts.media_source import (
|
||||||
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
|
)
|
||||||
|
from homeassistant.core import Context
|
||||||
|
from homeassistant.helpers import entity
|
||||||
|
from homeassistant.helpers.entity import EntityDescription
|
||||||
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
|
from .errors import SatelliteBusyError
|
||||||
|
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||||
|
"""A class that describes assist satellite entities."""
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntity(entity.Entity):
|
||||||
|
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||||
|
|
||||||
|
entity_description: AssistSatelliteEntityDescription
|
||||||
|
_attr_should_poll = False
|
||||||
|
_attr_state: AssistSatelliteState | None = None
|
||||||
|
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||||
|
|
||||||
|
_conversation_id: str | None = None
|
||||||
|
_conversation_id_time: float | None = None
|
||||||
|
|
||||||
|
_is_announcing: bool = False
|
||||||
|
_tts_finished_event: asyncio.Event | None = None
|
||||||
|
_wake_word_future: asyncio.Future[str | None] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_announcing(self) -> bool:
|
||||||
|
"""Returns true if currently announcing."""
|
||||||
|
return self._is_announcing
|
||||||
|
|
||||||
|
async def async_announce(
|
||||||
|
self,
|
||||||
|
text: str | None = None,
|
||||||
|
media_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Play an announcement on the satellite.
|
||||||
|
|
||||||
|
If media_id is not provided, text is synthesized to
|
||||||
|
audio with the selected pipeline.
|
||||||
|
|
||||||
|
Calls _internal_async_announce with media id and expects it to block
|
||||||
|
until the announcement is completed.
|
||||||
|
"""
|
||||||
|
if text is None:
|
||||||
|
text = ""
|
||||||
|
|
||||||
|
if not media_id:
|
||||||
|
# Synthesize audio and get URL
|
||||||
|
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||||
|
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||||
|
|
||||||
|
tts_options: dict[str, Any] = {}
|
||||||
|
if pipeline.tts_voice is not None:
|
||||||
|
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||||
|
|
||||||
|
media_id = tts_generate_media_source_id(
|
||||||
|
self.hass,
|
||||||
|
text,
|
||||||
|
engine=pipeline.tts_engine,
|
||||||
|
language=pipeline.tts_language,
|
||||||
|
options=tts_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
if media_source.is_media_source_id(media_id):
|
||||||
|
media = await media_source.async_resolve_media(
|
||||||
|
self.hass,
|
||||||
|
media_id,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
media_id = media.url
|
||||||
|
|
||||||
|
# Resolve to full URL
|
||||||
|
media_id = async_process_play_media_url(self.hass, media_id)
|
||||||
|
|
||||||
|
if self._is_announcing:
|
||||||
|
raise SatelliteBusyError
|
||||||
|
|
||||||
|
self._is_announcing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Block until announcement is finished
|
||||||
|
await self._internal_async_announce(media_id)
|
||||||
|
finally:
|
||||||
|
self._is_announcing = False
|
||||||
|
|
||||||
|
async def _internal_async_announce(self, media_id: str) -> None:
|
||||||
|
"""Announce the media URL on the satellite and returns when finished."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_intercepting_wake_word(self) -> bool:
|
||||||
|
"""Return true if next wake word will be intercepted."""
|
||||||
|
return (self._wake_word_future is not None) and (
|
||||||
|
not self._wake_word_future.cancelled()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_intercept_wake_word(self) -> str | None:
|
||||||
|
"""Intercept the next wake word from the satellite.
|
||||||
|
|
||||||
|
Returns the detected wake word phrase or None.
|
||||||
|
"""
|
||||||
|
if self._wake_word_future is not None:
|
||||||
|
raise SatelliteBusyError
|
||||||
|
|
||||||
|
# Will cause next wake word to be intercepted in
|
||||||
|
# _async_accept_pipeline_from_satellite
|
||||||
|
self._wake_word_future = asyncio.Future()
|
||||||
|
|
||||||
|
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._wake_word_future
|
||||||
|
finally:
|
||||||
|
self._wake_word_future = None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _async_accept_pipeline_from_satellite(
|
||||||
|
self,
|
||||||
|
audio_stream: AsyncIterable[bytes],
|
||||||
|
start_stage: PipelineStage = PipelineStage.STT,
|
||||||
|
end_stage: PipelineStage = PipelineStage.TTS,
|
||||||
|
pipeline_entity_id: str | None = None,
|
||||||
|
vad_sensitivity_entity_id: str | None = None,
|
||||||
|
wake_word_phrase: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Trigger an Assist pipeline in Home Assistant from a satellite."""
|
||||||
|
if self.is_intercepting_wake_word:
|
||||||
|
# Intercepting wake word and immediately end pipeline
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Intercepted wake word: %s (entity_id=%s)",
|
||||||
|
wake_word_phrase,
|
||||||
|
self.entity_id,
|
||||||
|
)
|
||||||
|
assert self._wake_word_future is not None
|
||||||
|
self._wake_word_future.set_result(wake_word_phrase)
|
||||||
|
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||||
|
|
||||||
|
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||||
|
if vad_sensitivity_entity_id:
|
||||||
|
if (
|
||||||
|
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||||
|
) is None:
|
||||||
|
raise ValueError("VAD sensitivity entity not found")
|
||||||
|
|
||||||
|
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||||
|
|
||||||
|
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||||
|
|
||||||
|
# Refresh context if necessary
|
||||||
|
if (
|
||||||
|
(self._context is None)
|
||||||
|
or (self._context_set is None)
|
||||||
|
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||||
|
):
|
||||||
|
self.async_set_context(Context())
|
||||||
|
|
||||||
|
assert self._context is not None
|
||||||
|
|
||||||
|
# Reset conversation id if necessary
|
||||||
|
if (self._conversation_id_time is None) or (
|
||||||
|
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||||
|
):
|
||||||
|
self._conversation_id = None
|
||||||
|
|
||||||
|
if self._conversation_id is None:
|
||||||
|
self._conversation_id = ulid.ulid()
|
||||||
|
|
||||||
|
# Update timeout
|
||||||
|
self._conversation_id_time = time.monotonic()
|
||||||
|
|
||||||
|
# Set entity state based on pipeline events
|
||||||
|
self._tts_finished_event = None
|
||||||
|
|
||||||
|
await async_pipeline_from_audio_stream(
|
||||||
|
self.hass,
|
||||||
|
context=self._context,
|
||||||
|
event_callback=self._internal_on_pipeline_event,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language="", # set in async_pipeline_from_audio_stream
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=audio_stream,
|
||||||
|
pipeline_id=pipeline_id,
|
||||||
|
conversation_id=self._conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
|
tts_audio_output="wav",
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
audio_settings=AudioSettings(
|
||||||
|
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||||
|
),
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
|
||||||
|
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Set state based on pipeline stage."""
|
||||||
|
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
elif event.type is PipelineEventType.STT_START:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||||
|
elif event.type is PipelineEventType.INTENT_START:
|
||||||
|
self._set_state(AssistSatelliteState.PROCESSING)
|
||||||
|
elif event.type is PipelineEventType.TTS_START:
|
||||||
|
# Wait until tts_response_finished is called to return to waiting state
|
||||||
|
self._tts_finished_event = asyncio.Event()
|
||||||
|
self._set_state(AssistSatelliteState.RESPONDING)
|
||||||
|
elif event.type is PipelineEventType.RUN_END:
|
||||||
|
if self._tts_finished_event is None:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
|
self.on_pipeline_event(event)
|
||||||
|
|
||||||
|
def _set_state(self, state: AssistSatelliteState):
|
||||||
|
"""Set the entity's state."""
|
||||||
|
self._attr_state = state
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
|
def tts_response_finished(self) -> None:
|
||||||
|
"""Tell entity that the text-to-speech response has finished playing."""
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
|
if self._tts_finished_event is not None:
|
||||||
|
self._tts_finished_event.set()
|
||||||
|
|
||||||
|
def _resolve_pipeline(self, pipeline_entity_id: str | None) -> str | None:
|
||||||
|
"""Resolve pipeline from select entity to id."""
|
||||||
|
if not pipeline_entity_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
|
||||||
|
raise ValueError("Pipeline entity not found")
|
||||||
|
|
||||||
|
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||||
|
# Resolve pipeline by name
|
||||||
|
for pipeline in async_get_pipelines(self.hass):
|
||||||
|
if pipeline.name == pipeline_entity_state.state:
|
||||||
|
return pipeline.id
|
||||||
|
|
||||||
|
return None
|
11
homeassistant/components/assist_satellite/errors.py
Normal file
11
homeassistant/components/assist_satellite/errors.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
"""Errors for assist satellite."""
|
||||||
|
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteError(HomeAssistantError):
|
||||||
|
"""Base class for assist satellite errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class SatelliteBusyError(AssistSatelliteError):
|
||||||
|
"""Satellite is busy and cannot handle the request."""
|
7
homeassistant/components/assist_satellite/icons.json
Normal file
7
homeassistant/components/assist_satellite/icons.json
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
{
|
||||||
|
"entity_component": {
|
||||||
|
"_": {
|
||||||
|
"default": "mdi:microphone-message"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"domain": "assist_satellite",
|
||||||
|
"name": "Assist Satellite",
|
||||||
|
"codeowners": ["@synesthesiam"],
|
||||||
|
"config_flow": false,
|
||||||
|
"dependencies": ["assist_pipeline", "stt", "tts"],
|
||||||
|
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||||
|
"integration_type": "entity"
|
||||||
|
}
|
26
homeassistant/components/assist_satellite/models.py
Normal file
26
homeassistant/components/assist_satellite/models.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
"""Models for assist satellite."""
|
||||||
|
|
||||||
|
from enum import IntFlag, StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteState(StrEnum):
|
||||||
|
"""Valid states of an Assist satellite entity."""
|
||||||
|
|
||||||
|
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||||
|
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||||
|
|
||||||
|
LISTENING_COMMAND = "listening_command"
|
||||||
|
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||||
|
|
||||||
|
PROCESSING = "processing"
|
||||||
|
"""Home Assistant is processing the voice command."""
|
||||||
|
|
||||||
|
RESPONDING = "responding"
|
||||||
|
"""Device is speaking the response."""
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntityFeature(IntFlag):
|
||||||
|
"""Supported features of Assist satellite entity."""
|
||||||
|
|
||||||
|
ANNOUNCE = 1
|
||||||
|
"""Device supports remotely triggered announcements."""
|
13
homeassistant/components/assist_satellite/strings.json
Normal file
13
homeassistant/components/assist_satellite/strings.json
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
{
|
||||||
|
"entity_component": {
|
||||||
|
"_": {
|
||||||
|
"name": "Assist satellite",
|
||||||
|
"state": {
|
||||||
|
"listening_wake_word": "Wake word",
|
||||||
|
"listening_command": "Voice command",
|
||||||
|
"responding": "Responding",
|
||||||
|
"processing": "Processing"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
42
homeassistant/components/assist_satellite/websocket_api.py
Normal file
42
homeassistant/components/assist_satellite/websocket_api.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
"""Assist satellite Websocket API."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import AssistSatelliteEntity
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
|
"""Register the websocket API."""
|
||||||
|
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||||
|
vol.Required("entity_id"): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_intercept_wake_word(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Intercept the next wake word from a satellite."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
satellite = component.get_entity(msg["entity_id"])
|
||||||
|
if satellite is None:
|
||||||
|
connection.send_error(msg["id"], "entity_not_found", "Entity not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||||
|
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
509
homeassistant/components/esphome/assist_satellite.py
Normal file
509
homeassistant/components/esphome/assist_satellite.py
Normal file
|
@ -0,0 +1,509 @@
|
||||||
|
"""Support for assist satellites in ESPHome."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from functools import partial
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from typing import Any, cast
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from aioesphomeapi import (
|
||||||
|
VoiceAssistantAudioSettings,
|
||||||
|
VoiceAssistantCommandFlag,
|
||||||
|
VoiceAssistantEventType,
|
||||||
|
VoiceAssistantFeature,
|
||||||
|
VoiceAssistantTimerEventType,
|
||||||
|
)
|
||||||
|
|
||||||
|
from homeassistant.components import assist_satellite, tts
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
)
|
||||||
|
from homeassistant.components.intent import async_register_timer_handler
|
||||||
|
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
||||||
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import EntityCategory, Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import EsphomeAssistEntity
|
||||||
|
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||||
|
from .enum_mapper import EsphomeEnumMapper
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||||
|
VoiceAssistantEventType, PipelineEventType
|
||||||
|
] = EsphomeEnumMapper(
|
||||||
|
{
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
||||||
|
EsphomeEnumMapper(
|
||||||
|
{
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entry: ESPHomeConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up Assist satellite entity."""
|
||||||
|
entry_data = entry.runtime_data
|
||||||
|
assert entry_data.device_info is not None
|
||||||
|
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
entry_data.api_version
|
||||||
|
):
|
||||||
|
async_add_entities(
|
||||||
|
[
|
||||||
|
EsphomeAssistSatellite(hass, entry, entry_data),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EsphomeAssistSatellite(
|
||||||
|
EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
|
||||||
|
):
|
||||||
|
"""Satellite running ESPHome."""
|
||||||
|
|
||||||
|
entity_description = assist_satellite.AssistSatelliteEntityDescription(
|
||||||
|
key="assist_satellite",
|
||||||
|
translation_key="assist_satellite",
|
||||||
|
entity_category=EntityCategory.CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
entry_data: RuntimeEntryData,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize satellite."""
|
||||||
|
super().__init__(entry_data)
|
||||||
|
|
||||||
|
self.hass = hass
|
||||||
|
self.config_entry = config_entry
|
||||||
|
self.entry_data = entry_data
|
||||||
|
self.cli = self.entry_data.client
|
||||||
|
|
||||||
|
self._is_running: bool = True
|
||||||
|
self._pipeline_task: asyncio.Task | None = None
|
||||||
|
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
|
self._tts_streaming_task: asyncio.Task | None = None
|
||||||
|
self._udp_server: VoiceAssistantUDPServer | None = None
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Run when entity about to be added to hass."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if feature_flags & VoiceAssistantFeature.API_AUDIO:
|
||||||
|
# TCP audio
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
self.cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self.handle_pipeline_start,
|
||||||
|
handle_stop=self.handle_pipeline_stop,
|
||||||
|
handle_audio=self.handle_audio,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# UDP audio
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
self.cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self.handle_pipeline_start,
|
||||||
|
handle_stop=self.handle_pipeline_stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_flags & VoiceAssistantFeature.TIMERS:
|
||||||
|
# Device supports timers
|
||||||
|
assert (self.registry_entry is not None) and (
|
||||||
|
self.registry_entry.device_id is not None
|
||||||
|
)
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
async_register_timer_handler(
|
||||||
|
self.hass, self.registry_entry.device_id, self.handle_timer_event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_flags & VoiceAssistantFeature.ANNOUNCE:
|
||||||
|
# Device supports announcements
|
||||||
|
self._attr_supported_features |= (
|
||||||
|
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
|
"""Run when entity will be removed from hass."""
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_pipeline()
|
||||||
|
|
||||||
|
async def _internal_async_announce(self, media_id: str) -> None:
|
||||||
|
self.cli.send_voice_assistant_announce(media_id)
|
||||||
|
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
try:
|
||||||
|
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
||||||
|
except KeyError:
|
||||||
|
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
||||||
|
return
|
||||||
|
|
||||||
|
data_to_send: dict[str, Any] = {}
|
||||||
|
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(True)
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {
|
||||||
|
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||||
|
}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {"text": event.data["tts_input"]}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||||
|
assert event.data is not None
|
||||||
|
tts_output = event.data["tts_output"]
|
||||||
|
if tts_output:
|
||||||
|
path = tts_output["url"]
|
||||||
|
url = async_process_play_media_url(self.hass, path)
|
||||||
|
data_to_send = {"url": url}
|
||||||
|
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||||
|
media_id = tts_output["media_id"]
|
||||||
|
self._tts_streaming_task = (
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._stream_tts_audio(media_id),
|
||||||
|
"esphome_voice_assistant_tts",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
|
assert event.data is not None
|
||||||
|
if not event.data["wake_word_output"]:
|
||||||
|
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||||
|
data_to_send = {
|
||||||
|
"code": "no_wake_word",
|
||||||
|
"message": "No wake word detected",
|
||||||
|
}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {
|
||||||
|
"code": event.data["code"],
|
||||||
|
"message": event.data["message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||||
|
|
||||||
|
async def handle_pipeline_start(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
flags: int,
|
||||||
|
audio_settings: VoiceAssistantAudioSettings,
|
||||||
|
wake_word_phrase: str | None,
|
||||||
|
) -> int | None:
|
||||||
|
"""Handle pipeline run request."""
|
||||||
|
# Clear audio queue
|
||||||
|
while not self._audio_queue.empty():
|
||||||
|
await self._audio_queue.get()
|
||||||
|
|
||||||
|
if self._tts_streaming_task is not None:
|
||||||
|
# Cancel current TTS response
|
||||||
|
self._tts_streaming_task.cancel()
|
||||||
|
self._tts_streaming_task = None
|
||||||
|
|
||||||
|
# API or UDP output audio
|
||||||
|
port: int = 0
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
|
||||||
|
feature_flags & VoiceAssistantFeature.API_AUDIO
|
||||||
|
):
|
||||||
|
port = await self._start_udp_server()
|
||||||
|
_LOGGER.debug("Started UDP server on port %s", port)
|
||||||
|
|
||||||
|
# Get entity ids for pipeline and finished speaking detection
|
||||||
|
ent_reg = er.async_get(self.hass)
|
||||||
|
pipeline_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
Platform.SELECT,
|
||||||
|
DOMAIN,
|
||||||
|
f"{self.entry_data.device_info.mac_address}-pipeline",
|
||||||
|
)
|
||||||
|
vad_sensitivity_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
Platform.SELECT,
|
||||||
|
DOMAIN,
|
||||||
|
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Device triggered pipeline (wake word, etc.)
|
||||||
|
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||||
|
start_stage = PipelineStage.WAKE_WORD
|
||||||
|
else:
|
||||||
|
start_stage = PipelineStage.STT
|
||||||
|
|
||||||
|
end_stage = PipelineStage.TTS
|
||||||
|
|
||||||
|
# Run the pipeline
|
||||||
|
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(True)
|
||||||
|
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._async_accept_pipeline_from_satellite(
|
||||||
|
audio_stream=self._wrap_audio_stream(),
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
pipeline_entity_id=pipeline_entity_id,
|
||||||
|
vad_sensitivity_entity_id=vad_sensitivity_entity_id,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
),
|
||||||
|
"esphome_assist_satellite_pipeline",
|
||||||
|
)
|
||||||
|
self._pipeline_task.add_done_callback(
|
||||||
|
lambda _future: self.handle_pipeline_finished()
|
||||||
|
)
|
||||||
|
|
||||||
|
return port
|
||||||
|
|
||||||
|
async def handle_audio(self, data: bytes) -> None:
|
||||||
|
"""Handle incoming audio chunk from API."""
|
||||||
|
self._audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
async def handle_pipeline_stop(self) -> None:
|
||||||
|
"""Handle request for pipeline to stop."""
|
||||||
|
self._stop_pipeline()
|
||||||
|
|
||||||
|
def handle_pipeline_finished(self) -> None:
|
||||||
|
"""Handle when pipeline has finished running."""
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(False)
|
||||||
|
self._stop_udp_server()
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
|
||||||
|
def handle_timer_event(
|
||||||
|
self, event_type: TimerEventType, timer_info: TimerInfo
|
||||||
|
) -> None:
|
||||||
|
"""Handle timer events."""
|
||||||
|
try:
|
||||||
|
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
||||||
|
except KeyError:
|
||||||
|
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cli.send_voice_assistant_timer_event(
|
||||||
|
native_event_type,
|
||||||
|
timer_info.id,
|
||||||
|
timer_info.name,
|
||||||
|
timer_info.created_seconds,
|
||||||
|
timer_info.seconds_left,
|
||||||
|
timer_info.is_active,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_tts_audio(
|
||||||
|
self,
|
||||||
|
media_id: str,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
sample_width: int = 2,
|
||||||
|
sample_channels: int = 1,
|
||||||
|
samples_per_chunk: int = 512,
|
||||||
|
) -> None:
|
||||||
|
"""Stream TTS audio chunks to device via API or UDP."""
|
||||||
|
self.cli.send_voice_assistant_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self._is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
self.hass,
|
||||||
|
media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if extension != "wav":
|
||||||
|
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||||
|
|
||||||
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
|
if (
|
||||||
|
(wav_file.getframerate() != sample_rate)
|
||||||
|
or (wav_file.getsampwidth() != sample_width)
|
||||||
|
or (wav_file.getnchannels() != sample_channels)
|
||||||
|
):
|
||||||
|
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||||
|
return
|
||||||
|
|
||||||
|
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||||
|
|
||||||
|
while True:
|
||||||
|
chunk = wav_file.readframes(samples_per_chunk)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self._udp_server is not None:
|
||||||
|
self._udp_server.send_audio_bytes(chunk)
|
||||||
|
else:
|
||||||
|
self.cli.send_voice_assistant_audio(chunk)
|
||||||
|
|
||||||
|
# Wait for 90% of the duration of the audio that was
|
||||||
|
# sent for it to be played. This will overrun the
|
||||||
|
# device's buffer for very long audio, so using a media
|
||||||
|
# player is preferred.
|
||||||
|
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||||
|
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||||
|
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return # Don't trigger state change
|
||||||
|
finally:
|
||||||
|
self.cli.send_voice_assistant_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# State change
|
||||||
|
self.tts_response_finished()
|
||||||
|
|
||||||
|
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
|
||||||
|
"""Yield audio chunks from the queue until None."""
|
||||||
|
while True:
|
||||||
|
chunk = await self._audio_queue.get()
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def _stop_pipeline(self) -> None:
|
||||||
|
"""Request pipeline to be stopped."""
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
_LOGGER.debug("Requested pipeline stop")
|
||||||
|
|
||||||
|
async def _start_udp_server(self) -> int:
|
||||||
|
"""Start a UDP server on a random free port."""
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock.setblocking(False)
|
||||||
|
sock.bind(("", 0)) # random free port
|
||||||
|
|
||||||
|
(
|
||||||
|
_transport,
|
||||||
|
protocol,
|
||||||
|
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(protocol, VoiceAssistantUDPServer)
|
||||||
|
self._udp_server = protocol
|
||||||
|
|
||||||
|
# Return port
|
||||||
|
return cast(int, sock.getsockname()[1])
|
||||||
|
|
||||||
|
def _stop_udp_server(self) -> None:
|
||||||
|
"""Stop the UDP server if it's running."""
|
||||||
|
if self._udp_server is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._udp_server.close()
|
||||||
|
finally:
|
||||||
|
self._udp_server = None
|
||||||
|
|
||||||
|
_LOGGER.debug("Stopped UDP server")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
|
"""Receive UDP packets and forward them to the audio queue."""
|
||||||
|
|
||||||
|
transport: asyncio.DatagramTransport | None = None
|
||||||
|
remote_addr: tuple[str, int] | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize protocol."""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._audio_queue = audio_queue
|
||||||
|
|
||||||
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
|
"""Store transport for later use."""
|
||||||
|
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||||
|
"""Handle incoming UDP packet."""
|
||||||
|
if self.remote_addr is None:
|
||||||
|
self.remote_addr = addr
|
||||||
|
|
||||||
|
self._audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
def error_received(self, exc: Exception) -> None:
|
||||||
|
"""Handle when a send or receive operation raises an OSError.
|
||||||
|
|
||||||
|
(Other than BlockingIOError or InterruptedError.)
|
||||||
|
"""
|
||||||
|
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||||
|
|
||||||
|
# Stop pipeline
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the receiver."""
|
||||||
|
if self.transport is not None:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
self.remote_addr = None
|
||||||
|
|
||||||
|
def send_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Send bytes to the device via UDP."""
|
||||||
|
if self.transport is None:
|
||||||
|
_LOGGER.error("No transport to send audio to")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.remote_addr is None:
|
||||||
|
_LOGGER.error("No address to send audio to")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.transport.sendto(data, self.remote_addr)
|
|
@ -27,12 +27,12 @@ from awesomeversion import AwesomeVersion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import tag, zeroconf
|
from homeassistant.components import tag, zeroconf
|
||||||
from homeassistant.components.intent import async_register_timer_handler
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DEVICE_ID,
|
ATTR_DEVICE_ID,
|
||||||
CONF_MODE,
|
CONF_MODE,
|
||||||
EVENT_HOMEASSISTANT_CLOSE,
|
EVENT_HOMEASSISTANT_CLOSE,
|
||||||
EVENT_LOGGING_CHANGED,
|
EVENT_LOGGING_CHANGED,
|
||||||
|
Platform,
|
||||||
)
|
)
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
Event,
|
Event,
|
||||||
|
@ -77,7 +77,6 @@ from .voice_assistant import (
|
||||||
VoiceAssistantAPIPipeline,
|
VoiceAssistantAPIPipeline,
|
||||||
VoiceAssistantPipeline,
|
VoiceAssistantPipeline,
|
||||||
VoiceAssistantUDPPipeline,
|
VoiceAssistantUDPPipeline,
|
||||||
handle_timer_event,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -500,29 +499,14 @@ class ESPHomeManager:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
flags = device_info.voice_assistant_feature_flags_compat(api_version)
|
if device_info.voice_assistant_feature_flags_compat(api_version) and (
|
||||||
if flags:
|
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
|
||||||
if flags & VoiceAssistantFeature.API_AUDIO:
|
):
|
||||||
entry_data.disconnect_callbacks.add(
|
# Create assist satellite entity
|
||||||
cli.subscribe_voice_assistant(
|
await self.hass.config_entries.async_forward_entry_setups(
|
||||||
handle_start=self._handle_pipeline_start,
|
self.entry, [Platform.ASSIST_SATELLITE]
|
||||||
handle_stop=self._handle_pipeline_stop,
|
|
||||||
handle_audio=self._handle_audio,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
entry_data.disconnect_callbacks.add(
|
|
||||||
cli.subscribe_voice_assistant(
|
|
||||||
handle_start=self._handle_pipeline_start,
|
|
||||||
handle_stop=self._handle_pipeline_stop,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if flags & VoiceAssistantFeature.TIMERS:
|
|
||||||
entry_data.disconnect_callbacks.add(
|
|
||||||
async_register_timer_handler(
|
|
||||||
hass, self.device_id, partial(handle_timer_event, cli)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
|
||||||
|
|
||||||
cli.subscribe_states(entry_data.async_update_state)
|
cli.subscribe_states(entry_data.async_update_state)
|
||||||
cli.subscribe_service_calls(self.async_on_service_call)
|
cli.subscribe_service_calls(self.async_on_service_call)
|
||||||
|
@ -844,4 +828,5 @@ async def cleanup_instance(
|
||||||
cleanup_callback()
|
cleanup_callback()
|
||||||
await data.async_cleanup()
|
await data.async_cleanup()
|
||||||
await data.client.disconnect()
|
await data.client.disconnect()
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -59,6 +59,17 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"entity": {
|
"entity": {
|
||||||
|
"assist_satellite": {
|
||||||
|
"assist_satellite": {
|
||||||
|
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
|
||||||
|
"state": {
|
||||||
|
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||||
|
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||||
|
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||||
|
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"binary_sensor": {
|
"binary_sensor": {
|
||||||
"assist_in_progress": {
|
"assist_in_progress": {
|
||||||
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
||||||
|
|
|
@ -20,6 +20,7 @@ from .devices import VoIPDevices
|
||||||
from .voip import HassVoipDatagramProtocol
|
from .voip import HassVoipDatagramProtocol
|
||||||
|
|
||||||
PLATFORMS = (
|
PLATFORMS = (
|
||||||
|
Platform.ASSIST_SATELLITE,
|
||||||
Platform.BINARY_SENSOR,
|
Platform.BINARY_SENSOR,
|
||||||
Platform.SELECT,
|
Platform.SELECT,
|
||||||
Platform.SWITCH,
|
Platform.SWITCH,
|
||||||
|
|
306
homeassistant/components/voip/assist_satellite.py
Normal file
306
homeassistant/components/voip/assist_satellite.py
Normal file
|
@ -0,0 +1,306 @@
|
||||||
|
"""Assist satellite entity for VoIP integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from enum import IntFlag
|
||||||
|
from functools import partial
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from voip_utils import RtpDatagramProtocol
|
||||||
|
|
||||||
|
from homeassistant.components import tts
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineNotFound,
|
||||||
|
)
|
||||||
|
from homeassistant.components.assist_satellite import (
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteEntityDescription,
|
||||||
|
AssistSatelliteState,
|
||||||
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.util.async_ import queue_to_iterable
|
||||||
|
|
||||||
|
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||||
|
from .devices import VoIPDevice
|
||||||
|
from .entity import VoIPEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import DomainData
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||||
|
|
||||||
|
|
||||||
|
class Tones(IntFlag):
|
||||||
|
"""Feedback tones for specific events."""
|
||||||
|
|
||||||
|
LISTENING = 1
|
||||||
|
PROCESSING = 2
|
||||||
|
ERROR = 4
|
||||||
|
|
||||||
|
|
||||||
|
_TONE_FILENAMES: dict[Tones, str] = {
|
||||||
|
Tones.LISTENING: "tone.pcm",
|
||||||
|
Tones.PROCESSING: "processing.pcm",
|
||||||
|
Tones.ERROR: "error.pcm",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up VoIP Assist satellite entity."""
|
||||||
|
domain_data: DomainData = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_device(device: VoIPDevice) -> None:
|
||||||
|
"""Add device."""
|
||||||
|
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
|
||||||
|
|
||||||
|
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||||
|
|
||||||
|
entities: list[VoIPEntity] = [
|
||||||
|
VoipAssistSatellite(hass, device, config_entry)
|
||||||
|
for device in domain_data.devices
|
||||||
|
]
|
||||||
|
|
||||||
|
async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
|
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||||
|
"""Assist satellite for VoIP devices."""
|
||||||
|
|
||||||
|
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||||
|
_attr_translation_key = "assist_satellite"
|
||||||
|
_attr_has_entity_name = True
|
||||||
|
_attr_name = None
|
||||||
|
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize an Assist satellite."""
|
||||||
|
VoIPEntity.__init__(self, voip_device)
|
||||||
|
AssistSatelliteEntity.__init__(self)
|
||||||
|
RtpDatagramProtocol.__init__(self)
|
||||||
|
|
||||||
|
self.config_entry = config_entry
|
||||||
|
|
||||||
|
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||||
|
self._audio_chunk_timeout: float = 2.0
|
||||||
|
self._pipeline_task: asyncio.Task | None = None
|
||||||
|
self._pipeline_had_error: bool = False
|
||||||
|
self._tts_done = asyncio.Event()
|
||||||
|
self._tts_extra_timeout: float = 1.0
|
||||||
|
self._tone_bytes: dict[Tones, bytes] = {}
|
||||||
|
self._tones = tones
|
||||||
|
self._processing_tone_done = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Run when entity about to be added to hass."""
|
||||||
|
self.voip_device.protocol = self
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
|
"""Run when entity will be removed from hass."""
|
||||||
|
assert self.voip_device.protocol == self
|
||||||
|
self.voip_device.protocol = None
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# VoIP
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||||
|
"""Handle raw audio chunk."""
|
||||||
|
if self._pipeline_task is None:
|
||||||
|
self._clear_audio_queue()
|
||||||
|
|
||||||
|
# Run pipeline until voice command finishes, then start over
|
||||||
|
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._run_pipeline(),
|
||||||
|
"voip_pipeline_run",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._audio_queue.put_nowait(audio_bytes)
|
||||||
|
|
||||||
|
async def _run_pipeline(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Forward audio to pipeline STT and handle TTS."""
|
||||||
|
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||||
|
self.voip_device.set_is_active(True)
|
||||||
|
|
||||||
|
# Play listening tone at the start of each cycle
|
||||||
|
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._tts_done.clear()
|
||||||
|
|
||||||
|
# Run pipeline with a timeout
|
||||||
|
_LOGGER.debug("Starting pipeline")
|
||||||
|
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||||
|
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
|
||||||
|
audio_stream=queue_to_iterable(
|
||||||
|
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||||
|
),
|
||||||
|
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
|
||||||
|
self.hass
|
||||||
|
),
|
||||||
|
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
|
||||||
|
self.hass
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._pipeline_had_error:
|
||||||
|
self._pipeline_had_error = False
|
||||||
|
await self._play_tone(Tones.ERROR)
|
||||||
|
else:
|
||||||
|
# Block until TTS is done speaking.
|
||||||
|
#
|
||||||
|
# This is set in _send_tts and has a timeout that's based on the
|
||||||
|
# length of the TTS audio.
|
||||||
|
await self._tts_done.wait()
|
||||||
|
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
except PipelineNotFound:
|
||||||
|
_LOGGER.warning("Pipeline not found")
|
||||||
|
except (asyncio.CancelledError, TimeoutError):
|
||||||
|
# Expected after caller hangs up
|
||||||
|
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||||
|
self.disconnect()
|
||||||
|
self._clear_audio_queue()
|
||||||
|
finally:
|
||||||
|
self.voip_device.set_is_active(False)
|
||||||
|
|
||||||
|
# Allow pipeline to run again
|
||||||
|
self._pipeline_task = None
|
||||||
|
|
||||||
|
def _clear_audio_queue(self) -> None:
|
||||||
|
"""Ensure audio queue is empty."""
|
||||||
|
while not self._audio_queue.empty():
|
||||||
|
self._audio_queue.get_nowait()
|
||||||
|
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Set state based on pipeline stage."""
|
||||||
|
if event.type == PipelineEventType.STT_END:
|
||||||
|
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||||
|
self._processing_tone_done.clear()
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
|
||||||
|
)
|
||||||
|
elif event.type == PipelineEventType.TTS_END:
|
||||||
|
# Send TTS audio to caller over RTP
|
||||||
|
if event.data and (tts_output := event.data["tts_output"]):
|
||||||
|
media_id = tts_output["media_id"]
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._send_tts(media_id),
|
||||||
|
"voip_pipeline_tts",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Empty TTS response
|
||||||
|
self._tts_done.set()
|
||||||
|
elif event.type == PipelineEventType.ERROR:
|
||||||
|
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||||
|
self._pipeline_had_error = True
|
||||||
|
|
||||||
|
async def _send_tts(self, media_id: str) -> None:
|
||||||
|
"""Send TTS audio to caller via RTP."""
|
||||||
|
try:
|
||||||
|
if self.transport is None:
|
||||||
|
return # not connected
|
||||||
|
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
self.hass,
|
||||||
|
media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if extension != "wav":
|
||||||
|
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||||
|
|
||||||
|
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||||
|
# Don't overlap TTS and processing beep
|
||||||
|
await self._processing_tone_done.wait()
|
||||||
|
|
||||||
|
with io.BytesIO(data) as wav_io:
|
||||||
|
with wave.open(wav_io, "rb") as wav_file:
|
||||||
|
sample_rate = wav_file.getframerate()
|
||||||
|
sample_width = wav_file.getsampwidth()
|
||||||
|
sample_channels = wav_file.getnchannels()
|
||||||
|
|
||||||
|
if (
|
||||||
|
(sample_rate != RATE)
|
||||||
|
or (sample_width != WIDTH)
|
||||||
|
or (sample_channels != CHANNELS)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||||
|
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||||
|
|
||||||
|
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||||
|
|
||||||
|
# Time out 1 second after TTS audio should be finished
|
||||||
|
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||||
|
tts_seconds = tts_samples / RATE
|
||||||
|
|
||||||
|
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
|
||||||
|
# TTS audio is 16Khz 16-bit mono
|
||||||
|
await self._async_send_audio(audio_bytes)
|
||||||
|
except TimeoutError:
|
||||||
|
_LOGGER.warning("TTS timeout")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Signal pipeline to restart
|
||||||
|
self._tts_done.set()
|
||||||
|
|
||||||
|
# Update satellite state
|
||||||
|
self.tts_response_finished()
|
||||||
|
|
||||||
|
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||||
|
"""Send audio in executor."""
|
||||||
|
await self.hass.async_add_executor_job(
|
||||||
|
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
|
||||||
|
"""Play a tone as feedback to the user if it's enabled."""
|
||||||
|
if (self._tones & tone) != tone:
|
||||||
|
return # not enabled
|
||||||
|
|
||||||
|
if tone not in self._tone_bytes:
|
||||||
|
# Do I/O in executor
|
||||||
|
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
|
||||||
|
self._load_pcm,
|
||||||
|
_TONE_FILENAMES[tone],
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._async_send_audio(
|
||||||
|
self._tone_bytes[tone],
|
||||||
|
silence_before=silence_before,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tone == Tones.PROCESSING:
|
||||||
|
self._processing_tone_done.set()
|
||||||
|
|
||||||
|
def _load_pcm(self, file_name: str) -> bytes:
|
||||||
|
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||||
|
return (Path(__file__).parent / file_name).read_bytes()
|
|
@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
|
||||||
"""Call when entity about to be added to hass."""
|
"""Call when entity about to be added to hass."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
|
self.async_on_remove(
|
||||||
|
self.voip_device.async_listen_update(self._is_active_changed)
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _is_active_changed(self, device: VoIPDevice) -> None:
|
def _is_active_changed(self, device: VoIPDevice) -> None:
|
||||||
"""Call when active state changed."""
|
"""Call when active state changed."""
|
||||||
self._attr_is_on = self._device.is_active
|
self._attr_is_on = self.voip_device.is_active
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
from collections.abc import Callable, Iterator
|
from collections.abc import Callable, Iterator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from voip_utils import CallInfo
|
from voip_utils import CallInfo, VoipDatagramProtocol
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
|
@ -22,6 +22,7 @@ class VoIPDevice:
|
||||||
device_id: str
|
device_id: str
|
||||||
is_active: bool = False
|
is_active: bool = False
|
||||||
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
||||||
|
protocol: VoipDatagramProtocol | None = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def set_is_active(self, active: bool) -> None:
|
def set_is_active(self, active: bool) -> None:
|
||||||
|
@ -56,6 +57,18 @@ class VoIPDevice:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return entity id for pipeline select."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
|
||||||
|
|
||||||
|
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return entity id for VAD sensitivity."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoIPDevices:
|
class VoIPDevices:
|
||||||
"""Class to store devices."""
|
"""Class to store devices."""
|
||||||
|
|
|
@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
|
||||||
_attr_has_entity_name = True
|
_attr_has_entity_name = True
|
||||||
_attr_should_poll = False
|
_attr_should_poll = False
|
||||||
|
|
||||||
def __init__(self, device: VoIPDevice) -> None:
|
def __init__(self, voip_device: VoIPDevice) -> None:
|
||||||
"""Initialize VoIP entity."""
|
"""Initialize VoIP entity."""
|
||||||
self._device = device
|
self.voip_device = voip_device
|
||||||
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
|
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
|
||||||
self._attr_device_info = DeviceInfo(
|
self._attr_device_info = DeviceInfo(
|
||||||
identifiers={(DOMAIN, device.voip_id)},
|
identifiers={(DOMAIN, voip_device.voip_id)},
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
"name": "Voice over IP",
|
"name": "Voice over IP",
|
||||||
"codeowners": ["@balloob", "@synesthesiam"],
|
"codeowners": ["@balloob", "@synesthesiam"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
"dependencies": ["assist_pipeline"],
|
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"quality_scale": "internal",
|
"quality_scale": "internal",
|
||||||
|
|
|
@ -10,6 +10,17 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"entity": {
|
"entity": {
|
||||||
|
"assist_satellite": {
|
||||||
|
"assist_satellite": {
|
||||||
|
"name": "[%key:component::assist_satellite::entity_component::_::name%]",
|
||||||
|
"state": {
|
||||||
|
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||||
|
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||||
|
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||||
|
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"binary_sensor": {
|
"binary_sensor": {
|
||||||
"call_in_progress": {
|
"call_in_progress": {
|
||||||
"name": "Call in progress"
|
"name": "Call in progress"
|
||||||
|
|
|
@ -3,15 +3,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
|
||||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
import wave
|
|
||||||
|
|
||||||
from voip_utils import (
|
from voip_utils import (
|
||||||
CallInfo,
|
CallInfo,
|
||||||
|
@ -21,33 +17,19 @@ from voip_utils import (
|
||||||
VoipDatagramProtocol,
|
VoipDatagramProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, stt, tts
|
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineEvent,
|
|
||||||
PipelineEventType,
|
|
||||||
PipelineNotFound,
|
PipelineNotFound,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
async_pipeline_from_audio_stream,
|
|
||||||
select as pipeline_select,
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_pipeline.audio_enhancer import (
|
|
||||||
AudioEnhancer,
|
|
||||||
MicroVadEnhancer,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_pipeline.vad import (
|
|
||||||
AudioBuffer,
|
|
||||||
VadSensitivity,
|
|
||||||
VoiceCommandSegmenter,
|
|
||||||
)
|
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.util.ulid import ulid_now
|
|
||||||
|
|
||||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .devices import VoIPDevice, VoIPDevices
|
from .devices import VoIPDevices
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -60,11 +42,8 @@ def make_protocol(
|
||||||
) -> VoipDatagramProtocol:
|
) -> VoipDatagramProtocol:
|
||||||
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
||||||
voip_device = devices.async_get_or_create(call_info)
|
voip_device = devices.async_get_or_create(call_info)
|
||||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
|
||||||
hass,
|
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
|
||||||
DOMAIN,
|
|
||||||
voip_device.voip_id,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||||
except PipelineNotFound:
|
except PipelineNotFound:
|
||||||
|
@ -83,22 +62,18 @@ def make_protocol(
|
||||||
rtcp_state=rtcp_state,
|
rtcp_state=rtcp_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
if (protocol := voip_device.protocol) is None:
|
||||||
hass,
|
raise ValueError("VoIP satellite not found")
|
||||||
DOMAIN,
|
|
||||||
voip_device.voip_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pipeline is properly configured
|
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||||
return PipelineRtpDatagramProtocol(
|
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||||
hass,
|
|
||||||
hass.config.language,
|
protocol.rtcp_state = rtcp_state
|
||||||
voip_device,
|
if protocol.rtcp_state is not None:
|
||||||
Context(user_id=devices.config_entry.data["user"]),
|
# Automatically disconnect when BYE is received over RTCP
|
||||||
opus_payload_type=call_info.opus_payload_type,
|
protocol.rtcp_state.bye_callback = protocol.disconnect
|
||||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
|
||||||
rtcp_state=rtcp_state,
|
return protocol
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||||
|
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||||
await self._closed_event.wait()
|
await self._closed_event.wait()
|
||||||
|
|
||||||
|
|
||||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|
||||||
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
language: str,
|
|
||||||
voip_device: VoIPDevice,
|
|
||||||
context: Context,
|
|
||||||
opus_payload_type: int,
|
|
||||||
pipeline_timeout: float = 30.0,
|
|
||||||
audio_timeout: float = 2.0,
|
|
||||||
buffered_chunks_before_speech: int = 100,
|
|
||||||
listening_tone_enabled: bool = True,
|
|
||||||
processing_tone_enabled: bool = True,
|
|
||||||
error_tone_enabled: bool = True,
|
|
||||||
tone_delay: float = 0.2,
|
|
||||||
tts_extra_timeout: float = 1.0,
|
|
||||||
silence_seconds: float = 1.0,
|
|
||||||
rtcp_state: RtcpState | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Set up pipeline RTP server."""
|
|
||||||
super().__init__(
|
|
||||||
rate=RATE,
|
|
||||||
width=WIDTH,
|
|
||||||
channels=CHANNELS,
|
|
||||||
opus_payload_type=opus_payload_type,
|
|
||||||
rtcp_state=rtcp_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hass = hass
|
|
||||||
self.language = language
|
|
||||||
self.voip_device = voip_device
|
|
||||||
self.pipeline: Pipeline | None = None
|
|
||||||
self.pipeline_timeout = pipeline_timeout
|
|
||||||
self.audio_timeout = audio_timeout
|
|
||||||
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
|
||||||
self.listening_tone_enabled = listening_tone_enabled
|
|
||||||
self.processing_tone_enabled = processing_tone_enabled
|
|
||||||
self.error_tone_enabled = error_tone_enabled
|
|
||||||
self.tone_delay = tone_delay
|
|
||||||
self.tts_extra_timeout = tts_extra_timeout
|
|
||||||
self.silence_seconds = silence_seconds
|
|
||||||
|
|
||||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
|
||||||
self._context = context
|
|
||||||
self._conversation_id: str | None = None
|
|
||||||
self._pipeline_task: asyncio.Task | None = None
|
|
||||||
self._tts_done = asyncio.Event()
|
|
||||||
self._session_id: str | None = None
|
|
||||||
self._tone_bytes: bytes | None = None
|
|
||||||
self._processing_bytes: bytes | None = None
|
|
||||||
self._error_bytes: bytes | None = None
|
|
||||||
self._pipeline_error: bool = False
|
|
||||||
|
|
||||||
def connection_made(self, transport):
|
|
||||||
"""Server is ready."""
|
|
||||||
super().connection_made(transport)
|
|
||||||
self.voip_device.set_is_active(True)
|
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
|
||||||
"""Handle connection is lost or closed."""
|
|
||||||
super().connection_lost(exc)
|
|
||||||
self.voip_device.set_is_active(False)
|
|
||||||
|
|
||||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
|
||||||
"""Handle raw audio chunk."""
|
|
||||||
if self._pipeline_task is None:
|
|
||||||
self._clear_audio_queue()
|
|
||||||
|
|
||||||
# Run pipeline until voice command finishes, then start over
|
|
||||||
self._pipeline_task = self.hass.async_create_background_task(
|
|
||||||
self._run_pipeline(),
|
|
||||||
"voip_pipeline_run",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._audio_queue.put_nowait(audio_bytes)
|
|
||||||
|
|
||||||
async def _run_pipeline(
|
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""Forward audio to pipeline STT and handle TTS."""
|
|
||||||
if self._session_id is None:
|
|
||||||
self._session_id = ulid_now()
|
|
||||||
|
|
||||||
# Play listening tone at the start of each cycle
|
|
||||||
if self.listening_tone_enabled:
|
|
||||||
await self._play_listening_tone()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for speech before starting pipeline
|
|
||||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
|
||||||
audio_enhancer = MicroVadEnhancer(0, 0, True)
|
|
||||||
chunk_buffer: deque[bytes] = deque(
|
|
||||||
maxlen=self.buffered_chunks_before_speech,
|
|
||||||
)
|
|
||||||
speech_detected = await self._wait_for_speech(
|
|
||||||
segmenter,
|
|
||||||
audio_enhancer,
|
|
||||||
chunk_buffer,
|
|
||||||
)
|
|
||||||
if not speech_detected:
|
|
||||||
_LOGGER.debug("No speech detected")
|
|
||||||
return
|
|
||||||
|
|
||||||
_LOGGER.debug("Starting pipeline")
|
|
||||||
self._tts_done.clear()
|
|
||||||
|
|
||||||
async def stt_stream():
|
|
||||||
try:
|
|
||||||
async for chunk in self._segment_audio(
|
|
||||||
segmenter,
|
|
||||||
audio_enhancer,
|
|
||||||
chunk_buffer,
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
if self.processing_tone_enabled:
|
|
||||||
await self._play_processing_tone()
|
|
||||||
except TimeoutError:
|
|
||||||
# Expected after caller hangs up
|
|
||||||
_LOGGER.debug("Audio timeout")
|
|
||||||
self._session_id = None
|
|
||||||
self.disconnect()
|
|
||||||
finally:
|
|
||||||
self._clear_audio_queue()
|
|
||||||
|
|
||||||
# Run pipeline with a timeout
|
|
||||||
async with asyncio.timeout(self.pipeline_timeout):
|
|
||||||
await async_pipeline_from_audio_stream(
|
|
||||||
self.hass,
|
|
||||||
context=self._context,
|
|
||||||
event_callback=self._event_callback,
|
|
||||||
stt_metadata=stt.SpeechMetadata(
|
|
||||||
language="", # set in async_pipeline_from_audio_stream
|
|
||||||
format=stt.AudioFormats.WAV,
|
|
||||||
codec=stt.AudioCodecs.PCM,
|
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
|
||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
|
||||||
),
|
|
||||||
stt_stream=stt_stream(),
|
|
||||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
|
||||||
self.hass, DOMAIN, self.voip_device.voip_id
|
|
||||||
),
|
|
||||||
conversation_id=self._conversation_id,
|
|
||||||
device_id=self.voip_device.device_id,
|
|
||||||
tts_audio_output="wav",
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._pipeline_error:
|
|
||||||
self._pipeline_error = False
|
|
||||||
if self.error_tone_enabled:
|
|
||||||
await self._play_error_tone()
|
|
||||||
else:
|
|
||||||
# Block until TTS is done speaking.
|
|
||||||
#
|
|
||||||
# This is set in _send_tts and has a timeout that's based on the
|
|
||||||
# length of the TTS audio.
|
|
||||||
await self._tts_done.wait()
|
|
||||||
|
|
||||||
_LOGGER.debug("Pipeline finished")
|
|
||||||
except PipelineNotFound:
|
|
||||||
_LOGGER.warning("Pipeline not found")
|
|
||||||
except TimeoutError:
|
|
||||||
# Expected after caller hangs up
|
|
||||||
_LOGGER.debug("Pipeline timeout")
|
|
||||||
self._session_id = None
|
|
||||||
self.disconnect()
|
|
||||||
finally:
|
|
||||||
# Allow pipeline to run again
|
|
||||||
self._pipeline_task = None
|
|
||||||
|
|
||||||
async def _wait_for_speech(
|
|
||||||
self,
|
|
||||||
segmenter: VoiceCommandSegmenter,
|
|
||||||
audio_enhancer: AudioEnhancer,
|
|
||||||
chunk_buffer: MutableSequence[bytes],
|
|
||||||
):
|
|
||||||
"""Buffer audio chunks until speech is detected.
|
|
||||||
|
|
||||||
Returns True if speech was detected, False otherwise.
|
|
||||||
"""
|
|
||||||
# Timeout if no audio comes in for a while.
|
|
||||||
# This means the caller hung up.
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self._audio_queue.get()
|
|
||||||
|
|
||||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
|
||||||
|
|
||||||
while chunk:
|
|
||||||
chunk_buffer.append(chunk)
|
|
||||||
|
|
||||||
segmenter.process_with_vad(
|
|
||||||
chunk,
|
|
||||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
|
||||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
|
||||||
vad_buffer,
|
|
||||||
)
|
|
||||||
if segmenter.in_command:
|
|
||||||
# Buffer until command starts
|
|
||||||
if len(vad_buffer) > 0:
|
|
||||||
chunk_buffer.append(vad_buffer.bytes())
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self._audio_queue.get()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _segment_audio(
|
|
||||||
self,
|
|
||||||
segmenter: VoiceCommandSegmenter,
|
|
||||||
audio_enhancer: AudioEnhancer,
|
|
||||||
chunk_buffer: Sequence[bytes],
|
|
||||||
) -> AsyncIterable[bytes]:
|
|
||||||
"""Yield audio chunks until voice command has finished."""
|
|
||||||
# Buffered chunks first
|
|
||||||
for buffered_chunk in chunk_buffer:
|
|
||||||
yield buffered_chunk
|
|
||||||
|
|
||||||
# Timeout if no audio comes in for a while.
|
|
||||||
# This means the caller hung up.
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self._audio_queue.get()
|
|
||||||
|
|
||||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
|
||||||
|
|
||||||
while chunk:
|
|
||||||
if not segmenter.process_with_vad(
|
|
||||||
chunk,
|
|
||||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
|
||||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
|
||||||
vad_buffer,
|
|
||||||
):
|
|
||||||
# Voice command is finished
|
|
||||||
break
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self._audio_queue.get()
|
|
||||||
|
|
||||||
def _clear_audio_queue(self) -> None:
|
|
||||||
while not self._audio_queue.empty():
|
|
||||||
self._audio_queue.get_nowait()
|
|
||||||
|
|
||||||
def _event_callback(self, event: PipelineEvent):
|
|
||||||
if not event.data:
|
|
||||||
return
|
|
||||||
|
|
||||||
if event.type == PipelineEventType.INTENT_END:
|
|
||||||
# Capture conversation id
|
|
||||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
|
||||||
elif event.type == PipelineEventType.TTS_END:
|
|
||||||
# Send TTS audio to caller over RTP
|
|
||||||
tts_output = event.data["tts_output"]
|
|
||||||
if tts_output:
|
|
||||||
media_id = tts_output["media_id"]
|
|
||||||
self.hass.async_create_background_task(
|
|
||||||
self._send_tts(media_id),
|
|
||||||
"voip_pipeline_tts",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Empty TTS response
|
|
||||||
self._tts_done.set()
|
|
||||||
elif event.type == PipelineEventType.ERROR:
|
|
||||||
# Play error tone instead of wait for TTS
|
|
||||||
self._pipeline_error = True
|
|
||||||
|
|
||||||
async def _send_tts(self, media_id: str) -> None:
|
|
||||||
"""Send TTS audio to caller via RTP."""
|
|
||||||
try:
|
|
||||||
if self.transport is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
|
||||||
self.hass,
|
|
||||||
media_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if extension != "wav":
|
|
||||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io:
|
|
||||||
with wave.open(wav_io, "rb") as wav_file:
|
|
||||||
sample_rate = wav_file.getframerate()
|
|
||||||
sample_width = wav_file.getsampwidth()
|
|
||||||
sample_channels = wav_file.getnchannels()
|
|
||||||
|
|
||||||
if (
|
|
||||||
(sample_rate != RATE)
|
|
||||||
or (sample_width != WIDTH)
|
|
||||||
or (sample_channels != CHANNELS)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
|
||||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
|
||||||
|
|
||||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
|
||||||
|
|
||||||
# Time out 1 second after TTS audio should be finished
|
|
||||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
|
||||||
tts_seconds = tts_samples / RATE
|
|
||||||
|
|
||||||
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
|
||||||
# TTS audio is 16Khz 16-bit mono
|
|
||||||
await self._async_send_audio(audio_bytes)
|
|
||||||
except TimeoutError:
|
|
||||||
_LOGGER.warning("TTS timeout")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# Signal pipeline to restart
|
|
||||||
self._tts_done.set()
|
|
||||||
|
|
||||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
|
||||||
"""Send audio in executor."""
|
|
||||||
await self.hass.async_add_executor_job(
|
|
||||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _play_listening_tone(self) -> None:
|
|
||||||
"""Play a tone to indicate that Home Assistant is listening."""
|
|
||||||
if self._tone_bytes is None:
|
|
||||||
# Do I/O in executor
|
|
||||||
self._tone_bytes = await self.hass.async_add_executor_job(
|
|
||||||
self._load_pcm,
|
|
||||||
"tone.pcm",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._async_send_audio(
|
|
||||||
self._tone_bytes,
|
|
||||||
silence_before=self.tone_delay,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _play_processing_tone(self) -> None:
|
|
||||||
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
|
||||||
if self._processing_bytes is None:
|
|
||||||
# Do I/O in executor
|
|
||||||
self._processing_bytes = await self.hass.async_add_executor_job(
|
|
||||||
self._load_pcm,
|
|
||||||
"processing.pcm",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._async_send_audio(self._processing_bytes)
|
|
||||||
|
|
||||||
async def _play_error_tone(self) -> None:
|
|
||||||
"""Play a tone to indicate a pipeline error occurred."""
|
|
||||||
if self._error_bytes is None:
|
|
||||||
# Do I/O in executor
|
|
||||||
self._error_bytes = await self.hass.async_add_executor_job(
|
|
||||||
self._load_pcm,
|
|
||||||
"error.pcm",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._async_send_audio(self._error_bytes)
|
|
||||||
|
|
||||||
def _load_pcm(self, file_name: str) -> bytes:
|
|
||||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
|
||||||
return (Path(__file__).parent / file_name).read_bytes()
|
|
||||||
|
|
||||||
|
|
||||||
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
||||||
"""Plays a pre-recorded message on a loop."""
|
"""Plays a pre-recorded message on a loop."""
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
||||||
|
|
||||||
AIR_QUALITY = "air_quality"
|
AIR_QUALITY = "air_quality"
|
||||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||||
|
ASSIST_SATELLITE = "assist_satellite"
|
||||||
BINARY_SENSOR = "binary_sensor"
|
BINARY_SENSOR = "binary_sensor"
|
||||||
BUTTON = "button"
|
BUTTON = "button"
|
||||||
CALENDAR = "calendar"
|
CALENDAR = "calendar"
|
||||||
|
|
|
@ -5,22 +5,28 @@ from __future__ import annotations
|
||||||
from asyncio import (
|
from asyncio import (
|
||||||
AbstractEventLoop,
|
AbstractEventLoop,
|
||||||
Future,
|
Future,
|
||||||
|
Queue,
|
||||||
Semaphore,
|
Semaphore,
|
||||||
Task,
|
Task,
|
||||||
TimerHandle,
|
TimerHandle,
|
||||||
gather,
|
gather,
|
||||||
get_running_loop,
|
get_running_loop,
|
||||||
|
timeout as async_timeout,
|
||||||
)
|
)
|
||||||
from collections.abc import Awaitable, Callable, Coroutine
|
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
||||||
|
|
||||||
|
_DataT = TypeVar("_DataT", default=Any)
|
||||||
|
|
||||||
|
|
||||||
def create_eager_task[_T](
|
def create_eager_task[_T](
|
||||||
coro: Coroutine[Any, Any, _T],
|
coro: Coroutine[Any, Any, _T],
|
||||||
|
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
|
||||||
"""Return a list of scheduled TimerHandles."""
|
"""Return a list of scheduled TimerHandles."""
|
||||||
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
||||||
return handles
|
return handles
|
||||||
|
|
||||||
|
|
||||||
|
async def queue_to_iterable(
|
||||||
|
queue: Queue[_DataT], timeout: float | None = None
|
||||||
|
) -> AsyncIterable[_DataT]:
|
||||||
|
"""Stream items from a queue until None with an optional timeout per item."""
|
||||||
|
if timeout is None:
|
||||||
|
while (item := await queue.get()) is not None:
|
||||||
|
yield item
|
||||||
|
else:
|
||||||
|
async with async_timeout(timeout):
|
||||||
|
item = await queue.get()
|
||||||
|
|
||||||
|
while item is not None:
|
||||||
|
yield item
|
||||||
|
async with async_timeout(timeout):
|
||||||
|
item = await queue.get()
|
||||||
|
|
1
tests/components/assist_satellite/__init__.py
Normal file
1
tests/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for Assist Satellite."""
|
106
tests/components/assist_satellite/conftest.py
Normal file
106
tests/components/assist_satellite/conftest.py
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
"""Test helpers for Assist Satellite."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||||
|
from homeassistant.components.assist_satellite import (
|
||||||
|
DOMAIN as AS_DOMAIN,
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteEntityFeature,
|
||||||
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import (
|
||||||
|
MockConfigEntry,
|
||||||
|
MockModule,
|
||||||
|
MockPlatform,
|
||||||
|
mock_config_flow,
|
||||||
|
mock_integration,
|
||||||
|
mock_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_DOMAIN = "test_satellite"
|
||||||
|
|
||||||
|
|
||||||
|
class MockAssistSatellite(AssistSatelliteEntity):
|
||||||
|
"""Mock Assist Satellite Entity."""
|
||||||
|
|
||||||
|
_attr_name = "Test Entity"
|
||||||
|
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the mock entity."""
|
||||||
|
self.events = []
|
||||||
|
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
self.events.append(event)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def entity() -> MockAssistSatellite:
|
||||||
|
"""Mock Assist Satellite Entity."""
|
||||||
|
return MockAssistSatellite()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||||
|
"""Mock config entry."""
|
||||||
|
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_components(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Initialize components."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
async def async_setup_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Set up test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def async_unload_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Unload test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||||
|
return True
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_setup_entry=async_setup_entry_init,
|
||||||
|
async_unload_entry=async_unload_entry_init,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||||
|
|
||||||
|
async def async_setup_entry_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up test tts platform via config entry."""
|
||||||
|
async_add_entities([entity])
|
||||||
|
|
||||||
|
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)
|
||||||
|
|
||||||
|
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||||
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
return config_entry
|
88
tests/components/assist_satellite/test_entity.py
Normal file
88
tests/components/assist_satellite/test_entity.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
"""Test the Assist Satellite entity."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
AudioSettings,
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
vad,
|
||||||
|
)
|
||||||
|
from homeassistant.components.assist_satellite import AssistSatelliteState
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
|
||||||
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
ENTITY_ID = "assist_satellite.test_entity"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entity_state(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test entity state represent events."""
|
||||||
|
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
entity.async_set_context(context)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||||
|
) as mock_start_pipeline:
|
||||||
|
await entity._async_accept_pipeline_from_satellite(audio_stream) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert mock_start_pipeline.called
|
||||||
|
kwargs = mock_start_pipeline.call_args[1]
|
||||||
|
assert kwargs["context"] is context
|
||||||
|
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||||
|
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||||
|
language="",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
)
|
||||||
|
assert kwargs["stt_stream"] is audio_stream
|
||||||
|
assert kwargs["pipeline_id"] is None
|
||||||
|
assert kwargs["device_id"] is None
|
||||||
|
assert kwargs["tts_audio_output"] == "wav"
|
||||||
|
assert kwargs["wake_word_phrase"] is None
|
||||||
|
assert kwargs["audio_settings"] == AudioSettings(
|
||||||
|
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||||
|
)
|
||||||
|
assert kwargs["start_stage"] == PipelineStage.STT
|
||||||
|
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||||
|
|
||||||
|
for event_type, expected_state in (
|
||||||
|
(PipelineEventType.RUN_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||||
|
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||||
|
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||||
|
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||||
|
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||||
|
(PipelineEventType.RUN_END, AssistSatelliteState.RESPONDING),
|
||||||
|
):
|
||||||
|
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == expected_state, event_type
|
||||||
|
|
||||||
|
entity.tts_response_finished()
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
"""Test the Assist Satellite websocket API."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
)
|
||||||
|
from homeassistant.components.assist_satellite import AssistSatelliteEntityFeature
|
||||||
|
from homeassistant.components.media_source import PlayMedia
|
||||||
|
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
ENTITY_ID = "assist_satellite.test_entity"
|
||||||
|
|
||||||
|
|
||||||
|
async def audio_stream() -> AsyncIterable[bytes]:
|
||||||
|
"""Empty audio stream."""
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test assist_satellite/intercept_wake_word command."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineInput.validate",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_recognize_intent",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_text_to_speech",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch.object(entity, "on_pipeline_event") as mock_on_pipeline_event,
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "assist_satellite/intercept_wake_word", "entity_id": ENTITY_ID}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for interception to start
|
||||||
|
while not entity.is_intercepting_wake_word:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
# Start a pipeline with a wake word
|
||||||
|
await entity._async_accept_pipeline_from_satellite(
|
||||||
|
audio_stream=audio_stream(),
|
||||||
|
start_stage=PipelineStage.STT,
|
||||||
|
end_stage=PipelineStage.TTS,
|
||||||
|
wake_word_phrase="test wake word",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that wake word was intercepted
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
assert response["result"] == {"wake_word_phrase": "test wake word"}
|
||||||
|
|
||||||
|
# Verify that only run end event was sent to pipeline
|
||||||
|
mock_on_pipeline_event.assert_called_once_with(
|
||||||
|
PipelineEvent(PipelineEventType.RUN_END, data=None, timestamp=ANY)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce_not_supported(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test assist_satellite/announce command with an entity that doesn't support announcements."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
entity, "_attr_supported_features", AssistSatelliteEntityFeature(0)
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/announce",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"media_id": "test media id",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"]["code"] == ERR_NOT_SUPPORTED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce_media_id(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test assist_satellite/announce command with media id."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
entity, "_internal_async_announce"
|
||||||
|
) as mock_internal_async_announce,
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/announce",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"media_id": "test media id",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
|
||||||
|
# Verify media id was passed through
|
||||||
|
mock_internal_async_announce.assert_called_once_with("test media id")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce_text(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test assist_satellite/announce command with text."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||||
|
return_value="",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.media_source.async_resolve_media",
|
||||||
|
return_value=PlayMedia(url="test media id", mime_type=""),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_process_play_media_url",
|
||||||
|
return_value="test media id",
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
entity, "_internal_async_announce"
|
||||||
|
) as mock_internal_async_announce,
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/announce",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"text": "test text",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
|
||||||
|
# Verify media id was passed through
|
||||||
|
mock_internal_async_announce.assert_called_once_with("test media id")
|
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
File diff suppressed because one or more lines are too long
|
@ -3,15 +3,26 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
from voip_utils import CallInfo
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, voip
|
from homeassistant.components import assist_pipeline, assist_satellite, voip
|
||||||
from homeassistant.components.voip.devices import VoIPDevice
|
from homeassistant.components.assist_satellite import (
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteState,
|
||||||
|
)
|
||||||
|
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||||
|
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||||
|
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||||
|
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||||
|
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||||
|
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
|
||||||
return wav_io.getvalue()
|
return wav_io.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def async_get_satellite_entity(
|
||||||
|
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||||
|
) -> AssistSatelliteEntity | None:
|
||||||
|
"""Get Assist satellite entity."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
|
||||||
|
)
|
||||||
|
if satellite_entity_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||||
|
assist_satellite.DOMAIN
|
||||||
|
]
|
||||||
|
return component.get_entity(satellite_entity_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_valid_call(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
call_info: CallInfo,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a call is now allowed from an unknown device."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
protocol = HassVoipDatagramProtocol(hass, voip_devices)
|
||||||
|
assert not protocol.is_valid_call(call_info)
|
||||||
|
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
allowed_call_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
|
||||||
|
)
|
||||||
|
assert allowed_call_entity_id is not None
|
||||||
|
state = hass.states.get(allowed_call_entity_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
|
||||||
|
# Allow calls
|
||||||
|
hass.states.async_set(allowed_call_entity_id, STATE_ON)
|
||||||
|
assert protocol.is_valid_call(call_info)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_calls_not_allowed(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
call_info: CallInfo,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a pre-recorded message is played when calls aren't allowed."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
|
||||||
|
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||||
|
assert protocol.file_name == "problem.pcm"
|
||||||
|
|
||||||
|
# Test the playback
|
||||||
|
done = asyncio.Event()
|
||||||
|
played_audio_bytes = b""
|
||||||
|
|
||||||
|
def send_audio(audio_bytes: bytes, **kwargs):
|
||||||
|
nonlocal played_audio_bytes
|
||||||
|
|
||||||
|
# Should be problem.pcm from components/voip
|
||||||
|
played_audio_bytes = audio_bytes
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
protocol.transport = Mock()
|
||||||
|
protocol.loop_delay = 0
|
||||||
|
with patch.object(protocol, "send_audio", send_audio):
|
||||||
|
protocol.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await done.wait()
|
||||||
|
|
||||||
|
assert sum(played_audio_bytes) > 0
|
||||||
|
assert played_audio_bytes == snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_not_found(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
call_info: CallInfo,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a pre-recorded message is played when a pipeline isn't found."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
|
||||||
|
):
|
||||||
|
protocol: PreRecordMessageProtocol = make_protocol(
|
||||||
|
hass, voip_devices, call_info
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||||
|
assert protocol.file_name == "problem.pcm"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_prepared(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
call_info: CallInfo,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that satellite is prepared for a call."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
pipeline = assist_pipeline.Pipeline(
|
||||||
|
conversation_engine="test",
|
||||||
|
conversation_language="en",
|
||||||
|
language="en",
|
||||||
|
name="test",
|
||||||
|
stt_engine="test",
|
||||||
|
stt_language="en",
|
||||||
|
tts_engine="test",
|
||||||
|
tts_language="en",
|
||||||
|
tts_voice=None,
|
||||||
|
wake_word_entity=None,
|
||||||
|
wake_word_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.voip.voip.async_get_pipeline",
|
||||||
|
return_value=pipeline,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
protocol = make_protocol(hass, voip_devices, call_info)
|
||||||
|
assert protocol == satellite
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline(
|
async def test_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
voip_device: VoIPDevice,
|
voip_device: VoIPDevice,
|
||||||
|
call_info: CallInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that pipeline function is called from RTP protocol."""
|
"""Test that pipeline function is called from RTP protocol."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def process_10ms(self, chunk):
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
"""Anything non-zero is speech."""
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
if sum(chunk) > 0:
|
voip_user_id = satellite.config_entry.data["user"]
|
||||||
return 1
|
assert voip_user_id
|
||||||
|
|
||||||
return 0
|
# Satellite is muted until a call begins
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
|
|
||||||
# Used to test that audio queue is cleared before pipeline starts
|
# Used to test that audio queue is cleared before pipeline starts
|
||||||
bad_chunk = bytes([1, 2, 3, 4])
|
bad_chunk = bytes([1, 2, 3, 4])
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
|
||||||
|
):
|
||||||
|
assert context.user_id == voip_user_id
|
||||||
assert device_id == voip_device.device_id
|
assert device_id == voip_device.device_id
|
||||||
|
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
in_command = False
|
||||||
|
async for chunk in stt_stream:
|
||||||
# Stream will end when VAD detects end of "speech"
|
# Stream will end when VAD detects end of "speech"
|
||||||
assert _chunk != bad_chunk
|
assert chunk != bad_chunk
|
||||||
|
if sum(chunk) > 0:
|
||||||
|
in_command = True
|
||||||
|
elif in_command:
|
||||||
|
break # done with command
|
||||||
|
|
||||||
# Test empty data
|
# Test empty data
|
||||||
event_callback(
|
event_callback(
|
||||||
|
@ -71,6 +229,38 @@ async def test_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_START,
|
||||||
|
data={"engine": "test", "metadata": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||||
|
|
||||||
|
# Fake STT result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "fake-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.INTENT_START,
|
||||||
|
data={
|
||||||
|
"engine": "test",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"intent_input": "fake-text",
|
||||||
|
"conversation_id": None,
|
||||||
|
"device_id": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||||
|
|
||||||
# Fake intent result
|
# Fake intent result
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
|
@ -83,6 +273,21 @@ async def test_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fake tts result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||||
|
data={
|
||||||
|
"engine": "test",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"voice": "test",
|
||||||
|
"tts_input": "fake-text",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
|
@ -91,6 +296,18 @@ async def test_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.RUN_END
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
|
def tts_response_finished():
|
||||||
|
original_tts_response_finished()
|
||||||
|
done.set()
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
async def async_get_media_source_audio(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
media_source_id: str,
|
media_source_id: str,
|
||||||
|
@ -100,102 +317,56 @@ async def test_pipeline(
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"pymicro_vad.MicroVad.Process10ms",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=process_10ms,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||||
new=async_get_media_source_audio,
|
new=async_get_media_source_audio,
|
||||||
),
|
),
|
||||||
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite._tones = Tones(0)
|
||||||
hass,
|
satellite.transport = Mock()
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
satellite.connection_made(satellite.transport)
|
||||||
Context(),
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
opus_payload_type=123,
|
|
||||||
listening_tone_enabled=False,
|
|
||||||
processing_tone_enabled=False,
|
|
||||||
error_tone_enabled=False,
|
|
||||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
|
||||||
)
|
|
||||||
rtp_protocol.transport = Mock()
|
|
||||||
|
|
||||||
# Ensure audio queue is cleared before pipeline starts
|
# Ensure audio queue is cleared before pipeline starts
|
||||||
rtp_protocol._audio_queue.put_nowait(bad_chunk)
|
satellite._audio_queue.put_nowait(bad_chunk)
|
||||||
|
|
||||||
def send_audio(*args, **kwargs):
|
def send_audio(*args, **kwargs):
|
||||||
# Test finished successfully
|
# Don't send audio
|
||||||
done.set()
|
pass
|
||||||
|
|
||||||
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
satellite.send_audio = Mock(side_effect=send_audio)
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes aggressive VAD sensitivity)
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
|
# Finished speaking
|
||||||
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
"""Test timeout during pipeline run."""
|
|
||||||
assert await async_setup_component(hass, "voip", {})
|
|
||||||
|
|
||||||
done = asyncio.Event()
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
|
|
||||||
return_value=True,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
|
||||||
hass,
|
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
pipeline_timeout=0.001,
|
|
||||||
listening_tone_enabled=False,
|
|
||||||
processing_tone_enabled=False,
|
|
||||||
error_tone_enabled=False,
|
|
||||||
)
|
|
||||||
transport = Mock(spec=["close"])
|
|
||||||
rtp_protocol.connection_made(transport)
|
|
||||||
|
|
||||||
# Closing the transport will cause the test to succeed
|
|
||||||
transport.close.side_effect = done.set
|
|
||||||
|
|
||||||
# silence
|
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
# Wait for mock pipeline to time out
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await done.wait()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
async def test_stt_stream_timeout(
|
||||||
|
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
|
||||||
|
) -> None:
|
||||||
"""Test timeout in STT stream during pipeline run."""
|
"""Test timeout in STT stream during pipeline run."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite._tones = Tones(0)
|
||||||
hass,
|
satellite._audio_chunk_timeout = 0.001
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
audio_timeout=0.001,
|
|
||||||
listening_tone_enabled=False,
|
|
||||||
processing_tone_enabled=False,
|
|
||||||
error_tone_enabled=False,
|
|
||||||
)
|
|
||||||
transport = Mock(spec=["close"])
|
transport = Mock(spec=["close"])
|
||||||
rtp_protocol.connection_made(transport)
|
satellite.connection_made(transport)
|
||||||
|
|
||||||
# Closing the transport will cause the test to succeed
|
# Closing the transport will cause the test to succeed
|
||||||
transport.close.side_effect = done.set
|
transport.close.side_effect = done.set
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to time out
|
# Wait for mock pipeline to time out
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
||||||
|
|
||||||
async def test_tts_timeout(
|
async def test_tts_timeout(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
voip_device: VoIPDevice,
|
voip_device: VoIPDevice,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that TTS will time out based on its length."""
|
"""Test that TTS will time out based on its length."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def process_10ms(self, chunk):
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
"""Anything non-zero is speech."""
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
if sum(chunk) > 0:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
in_command = False
|
||||||
# Stream will end when VAD detects end of "speech"
|
async for chunk in stt_stream:
|
||||||
pass
|
if sum(chunk) > 0:
|
||||||
|
in_command = True
|
||||||
|
elif in_command:
|
||||||
|
break # done with command
|
||||||
|
|
||||||
|
# Fake STT result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "fake-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fake intent result
|
# Fake intent result
|
||||||
event_callback(
|
event_callback(
|
||||||
|
@ -278,15 +448,7 @@ async def test_tts_timeout(
|
||||||
|
|
||||||
tone_bytes = bytes([1, 2, 3, 4])
|
tone_bytes = bytes([1, 2, 3, 4])
|
||||||
|
|
||||||
def send_audio(audio_bytes, **kwargs):
|
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||||
if audio_bytes == tone_bytes:
|
|
||||||
# Not TTS
|
|
||||||
return
|
|
||||||
|
|
||||||
# Block here to force a timeout in _send_tts
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
async def async_send_audio(audio_bytes, **kwargs):
|
|
||||||
if audio_bytes == tone_bytes:
|
if audio_bytes == tone_bytes:
|
||||||
# Not TTS
|
# Not TTS
|
||||||
return
|
return
|
||||||
|
@ -303,37 +465,22 @@ async def test_tts_timeout(
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"pymicro_vad.MicroVad.Process10ms",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=process_10ms,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||||
new=async_get_media_source_audio,
|
new=async_get_media_source_audio,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite._tts_extra_timeout = 0.001
|
||||||
hass,
|
for tone in Tones:
|
||||||
hass.config.language,
|
satellite._tone_bytes[tone] = tone_bytes
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
tts_extra_timeout=0.001,
|
|
||||||
listening_tone_enabled=True,
|
|
||||||
processing_tone_enabled=True,
|
|
||||||
error_tone_enabled=True,
|
|
||||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
|
||||||
)
|
|
||||||
rtp_protocol._tone_bytes = tone_bytes
|
|
||||||
rtp_protocol._processing_bytes = tone_bytes
|
|
||||||
rtp_protocol._error_bytes = tone_bytes
|
|
||||||
rtp_protocol.transport = Mock()
|
|
||||||
rtp_protocol.send_audio = Mock()
|
|
||||||
|
|
||||||
original_send_tts = rtp_protocol._send_tts
|
satellite.transport = Mock()
|
||||||
|
satellite.send_audio = Mock()
|
||||||
|
|
||||||
|
original_send_tts = satellite._send_tts
|
||||||
|
|
||||||
async def send_tts(*args, **kwargs):
|
async def send_tts(*args, **kwargs):
|
||||||
# Call original then end test successfully
|
# Call original then end test successfully
|
||||||
|
@ -342,17 +489,17 @@ async def test_tts_timeout(
|
||||||
|
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -361,26 +508,34 @@ async def test_tts_timeout(
|
||||||
|
|
||||||
async def test_tts_wrong_extension(
|
async def test_tts_wrong_extension(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
voip_device: VoIPDevice,
|
voip_device: VoIPDevice,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that TTS will only stream WAV audio."""
|
"""Test that TTS will only stream WAV audio."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def process_10ms(self, chunk):
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
"""Anything non-zero is speech."""
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
if sum(chunk) > 0:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
in_command = False
|
||||||
# Stream will end when VAD detects end of "speech"
|
async for chunk in stt_stream:
|
||||||
pass
|
if sum(chunk) > 0:
|
||||||
|
in_command = True
|
||||||
|
elif in_command:
|
||||||
|
break # done with command
|
||||||
|
|
||||||
|
# Fake STT result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "fake-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fake intent result
|
# Fake intent result
|
||||||
event_callback(
|
event_callback(
|
||||||
|
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"pymicro_vad.MicroVad.Process10ms",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=process_10ms,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||||
new=async_get_media_source_audio,
|
new=async_get_media_source_audio,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite.transport = Mock()
|
||||||
hass,
|
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
)
|
|
||||||
rtp_protocol.transport = Mock()
|
|
||||||
|
|
||||||
original_send_tts = rtp_protocol._send_tts
|
original_send_tts = satellite._send_tts
|
||||||
|
|
||||||
async def send_tts(*args, **kwargs):
|
async def send_tts(*args, **kwargs):
|
||||||
# Call original then end test successfully
|
# Call original then end test successfully
|
||||||
|
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
|
||||||
|
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
|
||||||
|
|
||||||
async def test_tts_wrong_wav_format(
|
async def test_tts_wrong_wav_format(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
voip_device: VoIPDevice,
|
voip_device: VoIPDevice,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that TTS will only stream WAV audio with a specific format."""
|
"""Test that TTS will only stream WAV audio with a specific format."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def process_10ms(self, chunk):
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
"""Anything non-zero is speech."""
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
if sum(chunk) > 0:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
in_command = False
|
||||||
# Stream will end when VAD detects end of "speech"
|
async for chunk in stt_stream:
|
||||||
pass
|
if sum(chunk) > 0:
|
||||||
|
in_command = True
|
||||||
|
elif in_command:
|
||||||
|
break # done with command
|
||||||
|
|
||||||
|
# Fake STT result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "fake-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fake intent result
|
# Fake intent result
|
||||||
event_callback(
|
event_callback(
|
||||||
|
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"pymicro_vad.MicroVad.Process10ms",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=process_10ms,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||||
new=async_get_media_source_audio,
|
new=async_get_media_source_audio,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite.transport = Mock()
|
||||||
hass,
|
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
)
|
|
||||||
rtp_protocol.transport = Mock()
|
|
||||||
|
|
||||||
original_send_tts = rtp_protocol._send_tts
|
original_send_tts = satellite._send_tts
|
||||||
|
|
||||||
async def send_tts(*args, **kwargs):
|
async def send_tts(*args, **kwargs):
|
||||||
# Call original then end test successfully
|
# Call original then end test successfully
|
||||||
|
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
|
||||||
|
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
|
||||||
|
|
||||||
async def test_empty_tts_output(
|
async def test_empty_tts_output(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
voip_device: VoIPDevice,
|
voip_device: VoIPDevice,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that TTS will not stream when output is empty."""
|
"""Test that TTS will not stream when output is empty."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def process_10ms(self, chunk):
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
"""Anything non-zero is speech."""
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
if sum(chunk) > 0:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
in_command = False
|
||||||
# Stream will end when VAD detects end of "speech"
|
async for chunk in stt_stream:
|
||||||
pass
|
if sum(chunk) > 0:
|
||||||
|
in_command = True
|
||||||
|
elif in_command:
|
||||||
|
break # done with command
|
||||||
|
|
||||||
|
# Fake STT result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "fake-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fake intent result
|
# Fake intent result
|
||||||
event_callback(
|
event_callback(
|
||||||
|
@ -605,37 +754,78 @@ async def test_empty_tts_output(
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"pymicro_vad.MicroVad.Process10ms",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=process_10ms,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||||
) as mock_send_tts,
|
) as mock_send_tts,
|
||||||
):
|
):
|
||||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
satellite.transport = Mock()
|
||||||
hass,
|
|
||||||
hass.config.language,
|
|
||||||
voip_device,
|
|
||||||
Context(),
|
|
||||||
opus_payload_type=123,
|
|
||||||
)
|
|
||||||
rtp_protocol.transport = Mock()
|
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||||
|
|
||||||
# Wait for mock pipeline to finish
|
# Wait for mock pipeline to finish
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await rtp_protocol._tts_done.wait()
|
await satellite._tts_done.wait()
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a pipeline error causes the error tone to be played."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
|
done = asyncio.Event()
|
||||||
|
played_audio_bytes = b""
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
# Fake error
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.ERROR,
|
||||||
|
data={"code": "error-code", "message": "error message"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||||
|
nonlocal played_audio_bytes
|
||||||
|
|
||||||
|
# Should be error.pcm from components/voip
|
||||||
|
played_audio_bytes = audio_bytes
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
satellite._tones = Tones.ERROR
|
||||||
|
satellite.transport = Mock()
|
||||||
|
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
|
# Wait for error tone to be played
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await done.wait()
|
||||||
|
|
||||||
|
assert sum(played_audio_bytes) > 0
|
||||||
|
assert played_audio_bytes == snapshot()
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||||
from homeassistant.components.wyoming import DOMAIN
|
from homeassistant.components.wyoming import DOMAIN
|
||||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline
|
from homeassistant.components import assist_pipeline
|
||||||
|
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
|
||||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
|
|
@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
|
||||||
timer_handle.cancel()
|
timer_handle.cancel()
|
||||||
timer_handle2.cancel()
|
timer_handle2.cancel()
|
||||||
timer_handle3.cancel()
|
timer_handle3.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_queue_to_iterable() -> None:
|
||||||
|
"""Test queue_to_iterable."""
|
||||||
|
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||||
|
expected_items = list(range(10))
|
||||||
|
|
||||||
|
for i in expected_items:
|
||||||
|
await queue.put(i)
|
||||||
|
|
||||||
|
# Will terminate the stream
|
||||||
|
await queue.put(None)
|
||||||
|
|
||||||
|
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
|
||||||
|
|
||||||
|
assert expected_items == actual_items
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
assert queue.empty()
|
||||||
|
|
||||||
|
# Time out on first item
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
# Should time out very quickly
|
||||||
|
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Check timeout on second item
|
||||||
|
assert queue.empty()
|
||||||
|
await queue.put(12345)
|
||||||
|
|
||||||
|
# Time out on second item
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||||
|
# Should time out very quickly
|
||||||
|
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||||
|
if item != 12345:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
assert queue.empty()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue