Compare commits
13 commits
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
cfef78a3e2 | ||
|
5b872b1511 | ||
|
35ec60d85f | ||
|
7bcef1be60 | ||
|
d52153331d | ||
|
77dfa41f8a | ||
|
8cfd4e7f17 | ||
|
c15e02eee3 | ||
|
dd3cd65bfc | ||
|
ee0c649687 | ||
|
4e27d8ec78 | ||
|
a1db430249 | ||
|
033bc1bbe5 |
43 changed files with 3357 additions and 2361 deletions
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/aseko_pool_live/ @milanmeu
|
||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||
/tests/components/assist_satellite/ @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
|
|
@ -17,6 +17,7 @@ from .const import (
|
|||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
OPTION_PREFERRED,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
SAMPLE_WIDTH,
|
||||
|
@ -58,6 +59,7 @@ __all__ = (
|
|||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
"OPTION_PREFERRED",
|
||||
"SAMPLES_PER_CHUNK",
|
||||
"SAMPLE_RATE",
|
||||
"SAMPLE_WIDTH",
|
||||
|
|
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
|||
MS_PER_CHUNK = 10
|
||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
|
|
@ -504,7 +504,7 @@ class AudioSettings:
|
|||
is_vad_enabled: bool = True
|
||||
"""True if VAD is used to determine the end of the voice command."""
|
||||
|
||||
silence_seconds: float = 0.5
|
||||
silence_seconds: float = 0.7
|
||||
"""Seconds of silence after voice command has ended."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -906,6 +906,8 @@ class PipelineRun:
|
|||
metadata,
|
||||
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||
)
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
raise # expected
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||
raise SpeechToTextError(
|
||||
|
|
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
|
|
65
homeassistant/components/assist_satellite/__init__.py
Normal file
65
homeassistant/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
"""Base class for assist satellite entities."""
|
||||
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
"AssistSatelliteEntityFeature",
|
||||
"AssistSatelliteState",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
|
||||
component.async_register_entity_service(
|
||||
"announce",
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Optional("message"): str,
|
||||
vol.Optional("media_id"): str,
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key("message", "media_id"),
|
||||
),
|
||||
"async_internal_announce",
|
||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||
)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
3
homeassistant/components/assist_satellite/const.py
Normal file
3
homeassistant/components/assist_satellite/const.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Constants for assist satellite."""
|
||||
|
||||
DOMAIN = "assist_satellite"
|
302
homeassistant/components/assist_satellite/entity.py
Normal file
302
homeassistant/components/assist_satellite/entity.py
Normal file
|
@ -0,0 +1,302 @@
|
|||
"""Assist satellite entity."""
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final
|
||||
|
||||
from homeassistant.components import media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .errors import AssistSatelliteError, SatelliteBusyError
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||
"""A class that describes binary sensor entities."""
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description: AssistSatelliteEntityDescription
|
||||
_attr_should_poll = False
|
||||
_attr_state: AssistSatelliteState | None = None
|
||||
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||
_attr_pipeline_entity_id: str | None = None
|
||||
_attr_vad_sensitivity_entity_id: str | None = None
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_run_has_tts: bool = False
|
||||
_is_announcing = False
|
||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Entity ID of the pipeline to use for the next conversation."""
|
||||
return self._attr_pipeline_entity_id
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
return self._attr_vad_sensitivity_entity_id
|
||||
|
||||
async def async_intercept_wake_word(self) -> str | None:
|
||||
"""Intercept the next wake word from the satellite.
|
||||
|
||||
Returns the detected wake word phrase or None.
|
||||
"""
|
||||
if self._wake_word_intercept_future is not None:
|
||||
raise SatelliteBusyError("Wake word interception already in progress")
|
||||
|
||||
# Will cause next wake word to be intercepted in
|
||||
# async_accept_pipeline_from_satellite
|
||||
self._wake_word_intercept_future = asyncio.Future()
|
||||
|
||||
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||
|
||||
try:
|
||||
return await self._wake_word_intercept_future
|
||||
finally:
|
||||
self._wake_word_intercept_future = None
|
||||
|
||||
async def async_internal_announce(
|
||||
self,
|
||||
message: str | None = None,
|
||||
media_id: str | None = None,
|
||||
) -> None:
|
||||
"""Play an announcement on the satellite.
|
||||
|
||||
If media_id is not provided, message is synthesized to
|
||||
audio with the selected pipeline.
|
||||
|
||||
Calls async_announce with media id.
|
||||
"""
|
||||
if message is None:
|
||||
message = ""
|
||||
|
||||
if not media_id:
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
None,
|
||||
)
|
||||
media_id = media.url
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
self._is_announcing = True
|
||||
|
||||
try:
|
||||
# Block until announcement is finished
|
||||
await self.async_announce(message, media_id)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||
if self._wake_word_intercept_future and start_stage in (
|
||||
PipelineStage.WAKE_WORD,
|
||||
PipelineStage.STT,
|
||||
):
|
||||
if start_stage == PipelineStage.WAKE_WORD:
|
||||
self._wake_word_intercept_future.set_exception(
|
||||
AssistSatelliteError(
|
||||
"Only on-device wake words currently supported"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Intercepting wake word and immediately end pipeline
|
||||
_LOGGER.debug(
|
||||
"Intercepted wake word: %s (entity_id=%s)",
|
||||
wake_word_phrase,
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
if wake_word_phrase is None:
|
||||
self._wake_word_intercept_future.set_exception(
|
||||
AssistSatelliteError("No wake word phrase provided")
|
||||
)
|
||||
else:
|
||||
self._wake_word_intercept_future.set_result(wake_word_phrase)
|
||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||
return
|
||||
|
||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
(self._context is None)
|
||||
or (self._context_set is None)
|
||||
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||
):
|
||||
self.async_set_context(Context())
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = ulid.ulid()
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._run_has_tts = False
|
||||
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=self._resolve_pipeline(),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=self._resolve_vad_sensitivity()
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
||||
@callback
|
||||
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type is PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type is PipelineEventType.RUN_END:
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
self.on_pipeline_event(event)
|
||||
|
||||
@callback
|
||||
def _set_state(self, state: AssistSatelliteState):
|
||||
"""Set the entity's state."""
|
||||
self._attr_state = state
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def tts_response_finished(self) -> None:
|
||||
"""Tell entity that the text-to-speech response has finished playing."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
@callback
|
||||
def _resolve_pipeline(self) -> str | None:
|
||||
"""Resolve pipeline from select entity to id."""
|
||||
if not (pipeline_entity_id := self.pipeline_entity_id):
|
||||
return None
|
||||
|
||||
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
|
||||
raise RuntimeError("Pipeline entity not found")
|
||||
|
||||
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||
# Resolve pipeline by name
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
return pipeline.id
|
||||
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _resolve_vad_sensitivity(self) -> float:
|
||||
"""Resolve VAD sensitivity from select entity to enum."""
|
||||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
|
||||
if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
|
||||
if (
|
||||
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||
) is None:
|
||||
raise RuntimeError("VAD sensitivity entity not found")
|
||||
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
return vad.VadSensitivity.to_seconds(vad_sensitivity)
|
11
homeassistant/components/assist_satellite/errors.py
Normal file
11
homeassistant/components/assist_satellite/errors.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""Errors for assist satellite."""
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
|
||||
class AssistSatelliteError(HomeAssistantError):
|
||||
"""Base class for assist satellite errors."""
|
||||
|
||||
|
||||
class SatelliteBusyError(AssistSatelliteError):
|
||||
"""Satellite is busy and cannot handle the request."""
|
12
homeassistant/components/assist_satellite/icons.json
Normal file
12
homeassistant/components/assist_satellite/icons.json
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"default": "mdi:account-voice"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"announce": {
|
||||
"service": "mdi:bullhorn"
|
||||
}
|
||||
}
|
||||
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "assist_satellite",
|
||||
"name": "Assist Satellite",
|
||||
"codeowners": ["@synesthesiam"],
|
||||
"config_flow": false,
|
||||
"dependencies": ["assist_pipeline", "stt", "tts"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity"
|
||||
}
|
26
homeassistant/components/assist_satellite/models.py
Normal file
26
homeassistant/components/assist_satellite/models.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""Models for assist satellite."""
|
||||
|
||||
from enum import IntFlag, StrEnum
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
||||
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||
|
||||
LISTENING_COMMAND = "listening_command"
|
||||
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||
|
||||
PROCESSING = "processing"
|
||||
"""Home Assistant is processing the voice command."""
|
||||
|
||||
RESPONDING = "responding"
|
||||
"""Device is speaking the response."""
|
||||
|
||||
|
||||
class AssistSatelliteEntityFeature(IntFlag):
|
||||
"""Supported features of Assist satellite entity."""
|
||||
|
||||
ANNOUNCE = 1
|
||||
"""Device supports remotely triggered announcements."""
|
16
homeassistant/components/assist_satellite/services.yaml
Normal file
16
homeassistant/components/assist_satellite/services.yaml
Normal file
|
@ -0,0 +1,16 @@
|
|||
announce:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
fields:
|
||||
message:
|
||||
required: false
|
||||
example: "Time to wake up!"
|
||||
selector:
|
||||
text:
|
||||
media_id:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
30
homeassistant/components/assist_satellite/strings.json
Normal file
30
homeassistant/components/assist_satellite/strings.json
Normal file
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"title": "Assist satellite",
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"name": "Assist satellite",
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"announce": {
|
||||
"name": "Announce",
|
||||
"description": "Let the satellite announce a message.",
|
||||
"fields": {
|
||||
"message": {
|
||||
"name": "Message",
|
||||
"description": "The message to announce."
|
||||
},
|
||||
"media_id": {
|
||||
"name": "Media ID",
|
||||
"description": "The media ID to announce instead of using text-to-speech."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
"""Assist satellite Websocket API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.async_response
|
||||
async def websocket_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Intercept the next wake word from a satellite."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
satellite = component.get_entity(msg["entity_id"])
|
||||
if satellite is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||
)
|
||||
return
|
||||
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
529
homeassistant/components/esphome/assist_satellite.py
Normal file
529
homeassistant/components/esphome/assist_satellite.py
Normal file
|
@ -0,0 +1,529 @@
|
|||
"""Support for assist satellites in ESPHome."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, cast
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_satellite, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EntityCategory, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import EsphomeAssistEntity
|
||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||
VoiceAssistantEventType, PipelineEventType
|
||||
] = EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
||||
}
|
||||
)
|
||||
|
||||
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
||||
EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
entry: ESPHomeConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Assist satellite entity."""
|
||||
entry_data = entry.runtime_data
|
||||
assert entry_data.device_info is not None
|
||||
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
entry_data.api_version
|
||||
):
|
||||
async_add_entities(
|
||||
[
|
||||
EsphomeAssistSatellite(hass, entry, entry_data),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class EsphomeAssistSatellite(
|
||||
EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
|
||||
):
|
||||
"""Satellite running ESPHome."""
|
||||
|
||||
entity_description = assist_satellite.AssistSatelliteEntityDescription(
|
||||
key="assist_satellite",
|
||||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
entry_data: RuntimeEntryData,
|
||||
) -> None:
|
||||
"""Initialize satellite."""
|
||||
super().__init__(entry_data)
|
||||
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.entry_data = entry_data
|
||||
self.cli = self.entry_data.client
|
||||
|
||||
self._is_running: bool = True
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._tts_streaming_task: asyncio.Task | None = None
|
||||
self._udp_server: VoiceAssistantUDPServer | None = None
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
assert self.entry_data.device_info is not None
|
||||
ent_reg = er.async_get(self.hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-pipeline",
|
||||
)
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
assert self.entry_data.device_info is not None
|
||||
ent_reg = er.async_get(self.hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||
)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if feature_flags & VoiceAssistantFeature.API_AUDIO:
|
||||
# TCP audio
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
self.cli.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
handle_audio=self.handle_audio,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# UDP audio
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
self.cli.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.TIMERS:
|
||||
# Device supports timers
|
||||
assert (self.registry_entry is not None) and (
|
||||
self.registry_entry.device_id is not None
|
||||
)
|
||||
self.entry_data.disconnect_callbacks.add(
|
||||
async_register_timer_handler(
|
||||
self.hass, self.registry_entry.device_id, self.handle_timer_event
|
||||
)
|
||||
)
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.ANNOUNCE:
|
||||
# Device supports announcements
|
||||
self._attr_supported_features |= (
|
||||
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
self._set_state(assist_satellite.AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
self._is_running = False
|
||||
self._stop_pipeline()
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
try:
|
||||
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
||||
return
|
||||
|
||||
data_to_send: dict[str, Any] = {}
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["tts_input"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||
assert event.data is not None
|
||||
if tts_output := event.data["tts_output"]:
|
||||
path = tts_output["url"]
|
||||
url = async_process_play_media_url(self.hass, path)
|
||||
data_to_send = {"url": url}
|
||||
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||
media_id = tts_output["media_id"]
|
||||
self._tts_streaming_task = (
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._stream_tts_audio(media_id),
|
||||
"esphome_voice_assistant_tts",
|
||||
)
|
||||
)
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||
assert event.data is not None
|
||||
if not event.data["wake_word_output"]:
|
||||
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||
data_to_send = {
|
||||
"code": "no_wake_word",
|
||||
"message": "No wake word detected",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"code": event.data["code"],
|
||||
"message": event.data["message"],
|
||||
}
|
||||
|
||||
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
_LOGGER.debug(
|
||||
"Waiting for announcement to finished (message=%s, media_id=%s)",
|
||||
message,
|
||||
media_id,
|
||||
)
|
||||
await self.cli.wait_voice_assistant_announce(media_id, message)
|
||||
|
||||
async def handle_pipeline_start(
|
||||
self,
|
||||
conversation_id: str,
|
||||
flags: int,
|
||||
audio_settings: VoiceAssistantAudioSettings,
|
||||
wake_word_phrase: str | None,
|
||||
) -> int | None:
|
||||
"""Handle pipeline run request."""
|
||||
# Clear audio queue
|
||||
while not self._audio_queue.empty():
|
||||
await self._audio_queue.get()
|
||||
|
||||
if self._tts_streaming_task is not None:
|
||||
# Cancel current TTS response
|
||||
self._tts_streaming_task.cancel()
|
||||
self._tts_streaming_task = None
|
||||
|
||||
# API or UDP output audio
|
||||
port: int = 0
|
||||
assert self.entry_data.device_info is not None
|
||||
feature_flags = (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
|
||||
feature_flags & VoiceAssistantFeature.API_AUDIO
|
||||
):
|
||||
port = await self._start_udp_server()
|
||||
_LOGGER.debug("Started UDP server on port %s", port)
|
||||
|
||||
# Device triggered pipeline (wake word, etc.)
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
start_stage = PipelineStage.STT
|
||||
|
||||
end_stage = PipelineStage.TTS
|
||||
|
||||
# Run the pipeline
|
||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self.async_accept_pipeline_from_satellite(
|
||||
audio_stream=self._wrap_audio_stream(),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
),
|
||||
"esphome_assist_satellite_pipeline",
|
||||
)
|
||||
self._pipeline_task.add_done_callback(
|
||||
lambda _future: self.handle_pipeline_finished()
|
||||
)
|
||||
|
||||
return port
|
||||
|
||||
async def handle_audio(self, data: bytes) -> None:
|
||||
"""Handle incoming audio chunk from API."""
|
||||
self._audio_queue.put_nowait(data)
|
||||
|
||||
async def handle_pipeline_stop(self) -> None:
|
||||
"""Handle request for pipeline to stop."""
|
||||
self._stop_pipeline()
|
||||
|
||||
def handle_pipeline_finished(self) -> None:
|
||||
"""Handle when pipeline has finished running."""
|
||||
self.entry_data.async_set_assist_pipeline_state(False)
|
||||
self._stop_udp_server()
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def handle_timer_event(
|
||||
self, event_type: TimerEventType, timer_info: TimerInfo
|
||||
) -> None:
|
||||
"""Handle timer events."""
|
||||
try:
|
||||
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
||||
return
|
||||
|
||||
self.cli.send_voice_assistant_timer_event(
|
||||
native_event_type,
|
||||
timer_info.id,
|
||||
timer_info.name,
|
||||
timer_info.created_seconds,
|
||||
timer_info.seconds_left,
|
||||
timer_info.is_active,
|
||||
)
|
||||
|
||||
async def _stream_tts_audio(
|
||||
self,
|
||||
media_id: str,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
sample_channels: int = 1,
|
||||
samples_per_chunk: int = 512,
|
||||
) -> None:
|
||||
"""Stream TTS audio chunks to device via API or UDP."""
|
||||
self.cli.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||
)
|
||||
|
||||
try:
|
||||
if not self._is_running:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
_LOGGER.error("Only WAV audio can be streamed, got %s", extension)
|
||||
return
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
or (wav_file.getsampwidth() != sample_width)
|
||||
or (wav_file.getnchannels() != sample_channels)
|
||||
):
|
||||
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||
|
||||
while self._is_running:
|
||||
chunk = wav_file.readframes(samples_per_chunk)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if self._udp_server is not None:
|
||||
self._udp_server.send_audio_bytes(chunk)
|
||||
else:
|
||||
self.cli.send_voice_assistant_audio(chunk)
|
||||
|
||||
# Wait for 90% of the duration of the audio that was
|
||||
# sent for it to be played. This will overrun the
|
||||
# device's buffer for very long audio, so using a media
|
||||
# player is preferred.
|
||||
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||
except asyncio.CancelledError:
|
||||
return # Don't trigger state change
|
||||
finally:
|
||||
self.cli.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||
)
|
||||
|
||||
# State change
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks from the queue until None."""
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
def _stop_pipeline(self) -> None:
|
||||
"""Request pipeline to be stopped."""
|
||||
self._audio_queue.put_nowait(None)
|
||||
_LOGGER.debug("Requested pipeline stop")
|
||||
|
||||
async def _start_udp_server(self) -> int:
|
||||
"""Start a UDP server on a random free port."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(("", 0)) # random free port
|
||||
|
||||
(
|
||||
_transport,
|
||||
protocol,
|
||||
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
|
||||
)
|
||||
|
||||
assert isinstance(protocol, VoiceAssistantUDPServer)
|
||||
self._udp_server = protocol
|
||||
|
||||
# Return port
|
||||
return cast(int, sock.getsockname()[1])
|
||||
|
||||
def _stop_udp_server(self) -> None:
|
||||
"""Stop the UDP server if it's running."""
|
||||
if self._udp_server is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._udp_server.close()
|
||||
finally:
|
||||
self._udp_server = None
|
||||
|
||||
_LOGGER.debug("Stopped UDP server")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
"""Receive UDP packets and forward them to the audio queue."""
|
||||
|
||||
transport: asyncio.DatagramTransport | None = None
|
||||
remote_addr: tuple[str, int] | None = None
|
||||
|
||||
def __init__(
|
||||
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize protocol."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._audio_queue = audio_queue
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Store transport for later use."""
|
||||
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Handle incoming UDP packet."""
|
||||
if self.remote_addr is None:
|
||||
self.remote_addr = addr
|
||||
|
||||
self._audio_queue.put_nowait(data)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
"""Handle when a send or receive operation raises an OSError.
|
||||
|
||||
(Other than BlockingIOError or InterruptedError.)
|
||||
"""
|
||||
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||
|
||||
# Stop pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the receiver."""
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
self.remote_addr = None
|
||||
|
||||
def send_audio_bytes(self, data: bytes) -> None:
|
||||
"""Send bytes to the device via UDP."""
|
||||
if self.transport is None:
|
||||
_LOGGER.error("No transport to send audio to")
|
||||
return
|
||||
|
||||
if self.remote_addr is None:
|
||||
_LOGGER.error("No address to send audio to")
|
||||
return
|
||||
|
||||
self.transport.sendto(data, self.remote_addr)
|
|
@ -20,19 +20,17 @@ from aioesphomeapi import (
|
|||
RequiresEncryptionAPIError,
|
||||
UserService,
|
||||
UserServiceArgType,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantFeature,
|
||||
)
|
||||
from awesomeversion import AwesomeVersion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import tag, zeroconf
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_ID,
|
||||
CONF_MODE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_LOGGING_CHANGED,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
|
@ -73,12 +71,6 @@ from .domain_data import DomainData
|
|||
|
||||
# Import config flow so that it's added to the registry
|
||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||
from .voice_assistant import (
|
||||
VoiceAssistantAPIPipeline,
|
||||
VoiceAssistantPipeline,
|
||||
VoiceAssistantUDPPipeline,
|
||||
handle_timer_event,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -149,7 +141,6 @@ class ESPHomeManager:
|
|||
"cli",
|
||||
"device_id",
|
||||
"domain_data",
|
||||
"voice_assistant_pipeline",
|
||||
"reconnect_logic",
|
||||
"zeroconf_instance",
|
||||
"entry_data",
|
||||
|
@ -173,7 +164,6 @@ class ESPHomeManager:
|
|||
self.cli = cli
|
||||
self.device_id: str | None = None
|
||||
self.domain_data = domain_data
|
||||
self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
|
||||
self.reconnect_logic: ReconnectLogic | None = None
|
||||
self.zeroconf_instance = zeroconf_instance
|
||||
self.entry_data = entry.runtime_data
|
||||
|
@ -338,77 +328,6 @@ class ESPHomeManager:
|
|||
entity_id, attribute, self.hass.states.get(entity_id)
|
||||
)
|
||||
|
||||
def _handle_pipeline_finished(self) -> None:
|
||||
self.entry_data.async_set_assist_pipeline_state(False)
|
||||
|
||||
if self.voice_assistant_pipeline is not None:
|
||||
if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline):
|
||||
self.voice_assistant_pipeline.close()
|
||||
self.voice_assistant_pipeline = None
|
||||
|
||||
async def _handle_pipeline_start(
|
||||
self,
|
||||
conversation_id: str,
|
||||
flags: int,
|
||||
audio_settings: VoiceAssistantAudioSettings,
|
||||
wake_word_phrase: str | None,
|
||||
) -> int | None:
|
||||
"""Start a voice assistant pipeline."""
|
||||
if self.voice_assistant_pipeline is not None:
|
||||
_LOGGER.warning("Previous Voice assistant pipeline was not stopped")
|
||||
self.voice_assistant_pipeline.stop()
|
||||
self.voice_assistant_pipeline = None
|
||||
|
||||
hass = self.hass
|
||||
assert self.entry_data.device_info is not None
|
||||
if (
|
||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
& VoiceAssistantFeature.API_AUDIO
|
||||
):
|
||||
self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
|
||||
hass,
|
||||
self.entry_data,
|
||||
self.cli.send_voice_assistant_event,
|
||||
self._handle_pipeline_finished,
|
||||
self.cli,
|
||||
)
|
||||
port = 0
|
||||
else:
|
||||
self.voice_assistant_pipeline = VoiceAssistantUDPPipeline(
|
||||
hass,
|
||||
self.entry_data,
|
||||
self.cli.send_voice_assistant_event,
|
||||
self._handle_pipeline_finished,
|
||||
)
|
||||
port = await self.voice_assistant_pipeline.start_server()
|
||||
|
||||
assert self.device_id is not None, "Device ID must be set"
|
||||
hass.async_create_background_task(
|
||||
self.voice_assistant_pipeline.run_pipeline(
|
||||
device_id=self.device_id,
|
||||
conversation_id=conversation_id or None,
|
||||
flags=flags,
|
||||
audio_settings=audio_settings,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
),
|
||||
"esphome.voice_assistant_pipeline.run_pipeline",
|
||||
)
|
||||
|
||||
return port
|
||||
|
||||
async def _handle_pipeline_stop(self) -> None:
|
||||
"""Stop a voice assistant pipeline."""
|
||||
if self.voice_assistant_pipeline is not None:
|
||||
self.voice_assistant_pipeline.stop()
|
||||
|
||||
async def _handle_audio(self, data: bytes) -> None:
|
||||
if self.voice_assistant_pipeline is None:
|
||||
return
|
||||
assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline)
|
||||
self.voice_assistant_pipeline.receive_audio_bytes(data)
|
||||
|
||||
async def on_connect(self) -> None:
|
||||
"""Subscribe to states and list entities on successful API login."""
|
||||
try:
|
||||
|
@ -509,29 +428,14 @@ class ESPHomeManager:
|
|||
)
|
||||
)
|
||||
|
||||
flags = device_info.voice_assistant_feature_flags_compat(api_version)
|
||||
if flags:
|
||||
if flags & VoiceAssistantFeature.API_AUDIO:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
cli.subscribe_voice_assistant(
|
||||
handle_start=self._handle_pipeline_start,
|
||||
handle_stop=self._handle_pipeline_stop,
|
||||
handle_audio=self._handle_audio,
|
||||
)
|
||||
)
|
||||
else:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
cli.subscribe_voice_assistant(
|
||||
handle_start=self._handle_pipeline_start,
|
||||
handle_stop=self._handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
if flags & VoiceAssistantFeature.TIMERS:
|
||||
entry_data.disconnect_callbacks.add(
|
||||
async_register_timer_handler(
|
||||
hass, self.device_id, partial(handle_timer_event, cli)
|
||||
)
|
||||
)
|
||||
if device_info.voice_assistant_feature_flags_compat(api_version) and (
|
||||
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
|
||||
):
|
||||
# Create assist satellite entity
|
||||
await self.hass.config_entries.async_forward_entry_setups(
|
||||
self.entry, [Platform.ASSIST_SATELLITE]
|
||||
)
|
||||
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
|
||||
|
||||
cli.subscribe_states(entry_data.async_update_state)
|
||||
cli.subscribe_service_calls(self.async_on_service_call)
|
||||
|
|
|
@ -59,6 +59,11 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"name": "[%key:component::assist_satellite::entity_component::_::name%]"
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"assist_in_progress": {
|
||||
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
||||
|
|
|
@ -1,479 +0,0 @@
|
|||
"""ESPHome voice assistant support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
from typing import cast
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
PipelineStage,
|
||||
WakeWordSettings,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.error import (
|
||||
WakeWordDetectionAborted,
|
||||
WakeWordDetectionError,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entry_data import RuntimeEntryData
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
|
||||
UDP_MAX_PACKET_SIZE = 1024
|
||||
|
||||
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||
VoiceAssistantEventType, PipelineEventType
|
||||
] = EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
||||
}
|
||||
)
|
||||
|
||||
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
||||
EsphomeEnumMapper(
|
||||
{
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class VoiceAssistantPipeline:
|
||||
"""Base abstract pipeline class."""
|
||||
|
||||
started = False
|
||||
stop_requested = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
entry_data: RuntimeEntryData,
|
||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||
handle_finished: Callable[[], None],
|
||||
) -> None:
|
||||
"""Initialize the pipeline."""
|
||||
self.context = Context()
|
||||
self.hass = hass
|
||||
self.entry_data = entry_data
|
||||
assert entry_data.device_info is not None
|
||||
self.device_info = entry_data.device_info
|
||||
|
||||
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self.handle_event = handle_event
|
||||
self.handle_finished = handle_finished
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_task: asyncio.Task | None = None
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""True if the pipeline is started and hasn't been asked to stop."""
|
||||
return self.started and (not self.stop_requested)
|
||||
|
||||
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
||||
"""Iterate over incoming packets."""
|
||||
while data := await self.queue.get():
|
||||
if not self.is_running:
|
||||
break
|
||||
|
||||
yield data
|
||||
|
||||
def _event_callback(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
||||
try:
|
||||
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
||||
return
|
||||
|
||||
data_to_send = None
|
||||
error = False
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["tts_input"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||
assert event.data is not None
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
path = tts_output["url"]
|
||||
url = async_process_play_media_url(self.hass, path)
|
||||
data_to_send = {"url": url}
|
||||
|
||||
if (
|
||||
self.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
& VoiceAssistantFeature.SPEAKER
|
||||
):
|
||||
media_id = tts_output["media_id"]
|
||||
self._tts_task = self.hass.async_create_background_task(
|
||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||
)
|
||||
else:
|
||||
self._tts_done.set()
|
||||
else:
|
||||
# Empty TTS response
|
||||
data_to_send = {}
|
||||
self._tts_done.set()
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||
assert event.data is not None
|
||||
if not event.data["wake_word_output"]:
|
||||
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||
data_to_send = {
|
||||
"code": "no_wake_word",
|
||||
"message": "No wake word detected",
|
||||
}
|
||||
error = True
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"code": event.data["code"],
|
||||
"message": event.data["message"],
|
||||
}
|
||||
error = True
|
||||
|
||||
self.handle_event(event_type, data_to_send)
|
||||
if error:
|
||||
self._tts_done.set()
|
||||
self.handle_finished()
|
||||
|
||||
async def run_pipeline(
|
||||
self,
|
||||
device_id: str,
|
||||
conversation_id: str | None,
|
||||
flags: int = 0,
|
||||
audio_settings: VoiceAssistantAudioSettings | None = None,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Run the Voice Assistant pipeline."""
|
||||
if audio_settings is None or audio_settings.volume_multiplier == 0:
|
||||
audio_settings = VoiceAssistantAudioSettings()
|
||||
|
||||
if (
|
||||
self.device_info.voice_assistant_feature_flags_compat(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
& VoiceAssistantFeature.SPEAKER
|
||||
):
|
||||
tts_audio_output = "wav"
|
||||
else:
|
||||
tts_audio_output = "mp3"
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
start_stage = PipelineStage.STT
|
||||
try:
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self.context,
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=self._iterate_packets(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.device_info.mac_address
|
||||
),
|
||||
conversation_id=conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output=tts_audio_output,
|
||||
start_stage=start_stage,
|
||||
wake_word_settings=WakeWordSettings(timeout=5),
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
noise_suppression_level=audio_settings.noise_suppression_level,
|
||||
auto_gain_dbfs=audio_settings.auto_gain,
|
||||
volume_multiplier=audio_settings.volume_multiplier,
|
||||
is_vad_enabled=bool(flags & VoiceAssistantCommandFlag.USE_VAD),
|
||||
silence_seconds=VadSensitivity.to_seconds(
|
||||
pipeline_select.get_vad_sensitivity(
|
||||
self.hass, DOMAIN, self.device_info.mac_address
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Block until TTS is done sending
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound as e:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
},
|
||||
)
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except WakeWordDetectionAborted:
|
||||
pass # Wake word detection was aborted and `handle_finished` is enough.
|
||||
except WakeWordDetectionError as e:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
self.handle_finished()
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to device via UDP."""
|
||||
# Always send stream start/end events
|
||||
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
||||
|
||||
try:
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != 16000)
|
||||
or (sample_width != 2)
|
||||
or (sample_channels != 1)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected rate/width/channels as 16000/2/1,"
|
||||
" got {sample_rate}/{sample_width}/{sample_channels}}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
audio_bytes_size = len(audio_bytes)
|
||||
|
||||
_LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)
|
||||
|
||||
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
|
||||
sample_offset = 0
|
||||
samples_left = audio_bytes_size // bytes_per_sample
|
||||
|
||||
while (samples_left > 0) and self.is_running:
|
||||
bytes_offset = sample_offset * bytes_per_sample
|
||||
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
|
||||
samples_in_chunk = len(chunk) // bytes_per_sample
|
||||
samples_left -= samples_in_chunk
|
||||
|
||||
self.send_audio_bytes(chunk)
|
||||
await asyncio.sleep(
|
||||
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
|
||||
)
|
||||
|
||||
sample_offset += samples_in_chunk
|
||||
finally:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||
)
|
||||
self._tts_task = None
|
||||
self._tts_done.set()
|
||||
|
||||
def send_audio_bytes(self, data: bytes) -> None:
|
||||
"""Send bytes to the device."""
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the pipeline."""
|
||||
self.queue.put_nowait(b"")
|
||||
|
||||
|
||||
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
|
||||
"""Receive UDP packets and forward them to the voice assistant."""
|
||||
|
||||
transport: asyncio.DatagramTransport | None = None
|
||||
remote_addr: tuple[str, int] | None = None
|
||||
|
||||
async def start_server(self) -> int:
|
||||
"""Start accepting connections."""
|
||||
|
||||
def accept_connection() -> VoiceAssistantUDPPipeline:
|
||||
"""Accept connection."""
|
||||
if self.started:
|
||||
raise RuntimeError("Can only start once")
|
||||
if self.stop_requested:
|
||||
raise RuntimeError("No longer accepting connections")
|
||||
|
||||
self.started = True
|
||||
return self
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
|
||||
sock.bind(("", UDP_PORT))
|
||||
|
||||
await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
accept_connection, sock=sock
|
||||
)
|
||||
|
||||
return cast(int, sock.getsockname()[1])
|
||||
|
||||
@callback
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Store transport for later use."""
|
||||
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
@callback
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Handle incoming UDP packet."""
|
||||
if not self.is_running:
|
||||
return
|
||||
if self.remote_addr is None:
|
||||
self.remote_addr = addr
|
||||
self.queue.put_nowait(data)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
"""Handle when a send or receive operation raises an OSError.
|
||||
|
||||
(Other than BlockingIOError or InterruptedError.)
|
||||
"""
|
||||
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||
self.handle_finished()
|
||||
|
||||
@callback
|
||||
def stop(self) -> None:
|
||||
"""Stop the receiver."""
|
||||
super().stop()
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the receiver."""
|
||||
self.started = False
|
||||
self.stop_requested = True
|
||||
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
def send_audio_bytes(self, data: bytes) -> None:
|
||||
"""Send bytes to the device via UDP."""
|
||||
if self.transport is None:
|
||||
_LOGGER.error("No transport to send audio to")
|
||||
return
|
||||
self.transport.sendto(data, self.remote_addr)
|
||||
|
||||
|
||||
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
|
||||
"""Send audio to the voice assistant via the API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
entry_data: RuntimeEntryData,
|
||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||
handle_finished: Callable[[], None],
|
||||
api_client: APIClient,
|
||||
) -> None:
|
||||
"""Initialize the pipeline."""
|
||||
super().__init__(hass, entry_data, handle_event, handle_finished)
|
||||
self.api_client = api_client
|
||||
self.started = True
|
||||
|
||||
def send_audio_bytes(self, data: bytes) -> None:
|
||||
"""Send bytes to the device via the API."""
|
||||
self.api_client.send_voice_assistant_audio(data)
|
||||
|
||||
@callback
|
||||
def receive_audio_bytes(self, data: bytes) -> None:
|
||||
"""Receive audio bytes from the device."""
|
||||
if not self.is_running:
|
||||
return
|
||||
self.queue.put_nowait(data)
|
||||
|
||||
@callback
|
||||
def stop(self) -> None:
|
||||
"""Stop the pipeline."""
|
||||
super().stop()
|
||||
|
||||
self.started = False
|
||||
self.stop_requested = True
|
||||
|
||||
|
||||
def handle_timer_event(
|
||||
api_client: APIClient, event_type: TimerEventType, timer_info: TimerInfo
|
||||
) -> None:
|
||||
"""Handle timer events."""
|
||||
try:
|
||||
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
||||
except KeyError:
|
||||
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
||||
return
|
||||
|
||||
api_client.send_voice_assistant_timer_event(
|
||||
native_event_type,
|
||||
timer_info.id,
|
||||
timer_info.name,
|
||||
timer_info.created_seconds,
|
||||
timer_info.seconds_left,
|
||||
timer_info.is_active,
|
||||
)
|
|
@ -20,6 +20,7 @@ from .devices import VoIPDevices
|
|||
from .voip import HassVoipDatagramProtocol
|
||||
|
||||
PLATFORMS = (
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
|
312
homeassistant/components/voip/assist_satellite.py
Normal file
312
homeassistant/components/voip/assist_satellite.py
Normal file
|
@ -0,0 +1,312 @@
|
|||
"""Assist satellite entity for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import IntFlag
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
import wave
|
||||
|
||||
from voip_utils import RtpDatagramProtocol
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.async_ import queue_to_iterable
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||
|
||||
|
||||
class Tones(IntFlag):
|
||||
"""Feedback tones for specific events."""
|
||||
|
||||
LISTENING = 1
|
||||
PROCESSING = 2
|
||||
ERROR = 4
|
||||
|
||||
|
||||
_TONE_FILENAMES: dict[Tones, str] = {
|
||||
Tones.LISTENING: "tone.pcm",
|
||||
Tones.PROCESSING: "processing.pcm",
|
||||
Tones.ERROR: "error.pcm",
|
||||
}
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP Assist satellite entity."""
|
||||
domain_data: DomainData = hass.data[DOMAIN]
|
||||
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
entities: list[VoIPEntity] = [
|
||||
VoipAssistSatellite(hass, device, config_entry)
|
||||
for device in domain_data.devices
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||
"""Assist satellite for VoIP devices."""
|
||||
|
||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
voip_device: VoIPDevice,
|
||||
config_entry: ConfigEntry,
|
||||
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
|
||||
) -> None:
|
||||
"""Initialize an Assist satellite."""
|
||||
VoIPEntity.__init__(self, voip_device)
|
||||
AssistSatelliteEntity.__init__(self)
|
||||
RtpDatagramProtocol.__init__(self)
|
||||
|
||||
self.config_entry = config_entry
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._audio_chunk_timeout: float = 2.0
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._pipeline_had_error: bool = False
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_extra_timeout: float = 1.0
|
||||
self._tone_bytes: dict[Tones, bytes] = {}
|
||||
self._tones = tones
|
||||
self._processing_tone_done = asyncio.Event()
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
return self.voip_device.get_pipeline_entity_id(self.hass)
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
return self.voip_device.get_vad_sensitivity_entity_id(self.hass)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
self.voip_device.protocol = self
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
await super().async_will_remove_from_hass()
|
||||
assert self.voip_device.protocol == self
|
||||
self.voip_device.protocol = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||
|
||||
try:
|
||||
self._tts_done.clear()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||
await self.async_accept_pipeline_from_satellite(
|
||||
audio_stream=queue_to_iterable(
|
||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||
),
|
||||
)
|
||||
|
||||
if self._pipeline_had_error:
|
||||
self._pipeline_had_error = False
|
||||
await self._play_tone(Tones.ERROR)
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||
self.disconnect()
|
||||
self._clear_audio_queue()
|
||||
finally:
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type == PipelineEventType.STT_END:
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
self._processing_tone_done.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
|
||||
)
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||
self._pipeline_had_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return # not connected
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
# Don't overlap TTS and processing beep
|
||||
await self._processing_tone_done.wait()
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
# Update satellite state
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
|
||||
"""Play a tone as feedback to the user if it's enabled."""
|
||||
if (self._tones & tone) != tone:
|
||||
return # not enabled
|
||||
|
||||
if tone not in self._tone_bytes:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
_TONE_FILENAMES[tone],
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes[tone],
|
||||
silence_before=silence_before,
|
||||
)
|
||||
|
||||
if tone == Tones.PROCESSING:
|
||||
self._processing_tone_done.set()
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
|
@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
|
|||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
|
||||
self.async_on_remove(
|
||||
self.voip_device.async_listen_update(self._is_active_changed)
|
||||
)
|
||||
|
||||
@callback
|
||||
def _is_active_changed(self, device: VoIPDevice) -> None:
|
||||
"""Call when active state changed."""
|
||||
self._attr_is_on = self._device.is_active
|
||||
self._attr_is_on = self.voip_device.is_active
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voip_utils import CallInfo
|
||||
from voip_utils import CallInfo, VoipDatagramProtocol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
@ -22,6 +22,7 @@ class VoIPDevice:
|
|||
device_id: str
|
||||
is_active: bool = False
|
||||
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
||||
protocol: VoipDatagramProtocol | None = None
|
||||
|
||||
@callback
|
||||
def set_is_active(self, active: bool) -> None:
|
||||
|
@ -56,6 +57,18 @@ class VoIPDevice:
|
|||
|
||||
return False
|
||||
|
||||
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for pipeline select."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
|
||||
|
||||
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for VAD sensitivity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
|
||||
)
|
||||
|
||||
|
||||
class VoIPDevices:
|
||||
"""Class to store devices."""
|
||||
|
|
|
@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
|
|||
_attr_has_entity_name = True
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, device: VoIPDevice) -> None:
|
||||
def __init__(self, voip_device: VoIPDevice) -> None:
|
||||
"""Initialize VoIP entity."""
|
||||
self._device = device
|
||||
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
|
||||
self.voip_device = voip_device
|
||||
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, device.voip_id)},
|
||||
identifiers={(DOMAIN, voip_device.voip_id)},
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline"],
|
||||
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
|
|
@ -10,6 +10,16 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"call_in_progress": {
|
||||
"name": "Call in progress"
|
||||
|
|
|
@ -3,15 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import wave
|
||||
|
||||
from voip_utils import (
|
||||
CallInfo,
|
||||
|
@ -21,33 +17,19 @@ from voip_utils import (
|
|||
VoipDatagramProtocol,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_get_pipeline,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.audio_enhancer import (
|
||||
AudioEnhancer,
|
||||
MicroVadSpeexEnhancer,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import (
|
||||
AudioBuffer,
|
||||
VadSensitivity,
|
||||
VoiceCommandSegmenter,
|
||||
)
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
from .devices import VoIPDevices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,11 +42,8 @@ def make_protocol(
|
|||
) -> VoipDatagramProtocol:
|
||||
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
||||
voip_device = devices.async_get_or_create(call_info)
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
|
||||
try:
|
||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||
except PipelineNotFound:
|
||||
|
@ -83,22 +62,18 @@ def make_protocol(
|
|||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
if (protocol := voip_device.protocol) is None:
|
||||
raise ValueError("VoIP satellite not found")
|
||||
|
||||
# Pipeline is properly configured
|
||||
return PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(user_id=devices.config_entry.data["user"]),
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
|
||||
protocol.rtcp_state = rtcp_state
|
||||
if protocol.rtcp_state is not None:
|
||||
# Automatically disconnect when BYE is received over RTCP
|
||||
protocol.rtcp_state.bye_callback = protocol.disconnect
|
||||
|
||||
return protocol
|
||||
|
||||
|
||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||
|
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
|||
await self._closed_event.wait()
|
||||
|
||||
|
||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
language: str,
|
||||
voip_device: VoIPDevice,
|
||||
context: Context,
|
||||
opus_payload_type: int,
|
||||
pipeline_timeout: float = 30.0,
|
||||
audio_timeout: float = 2.0,
|
||||
buffered_chunks_before_speech: int = 100,
|
||||
listening_tone_enabled: bool = True,
|
||||
processing_tone_enabled: bool = True,
|
||||
error_tone_enabled: bool = True,
|
||||
tone_delay: float = 0.2,
|
||||
tts_extra_timeout: float = 1.0,
|
||||
silence_seconds: float = 1.0,
|
||||
rtcp_state: RtcpState | None = None,
|
||||
) -> None:
|
||||
"""Set up pipeline RTP server."""
|
||||
super().__init__(
|
||||
rate=RATE,
|
||||
width=WIDTH,
|
||||
channels=CHANNELS,
|
||||
opus_payload_type=opus_payload_type,
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
self.hass = hass
|
||||
self.language = language
|
||||
self.voip_device = voip_device
|
||||
self.pipeline: Pipeline | None = None
|
||||
self.pipeline_timeout = pipeline_timeout
|
||||
self.audio_timeout = audio_timeout
|
||||
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
||||
self.listening_tone_enabled = listening_tone_enabled
|
||||
self.processing_tone_enabled = processing_tone_enabled
|
||||
self.error_tone_enabled = error_tone_enabled
|
||||
self.tone_delay = tone_delay
|
||||
self.tts_extra_timeout = tts_extra_timeout
|
||||
self.silence_seconds = silence_seconds
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._context = context
|
||||
self._conversation_id: str | None = None
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._tts_done = asyncio.Event()
|
||||
self._session_id: str | None = None
|
||||
self._tone_bytes: bytes | None = None
|
||||
self._processing_bytes: bytes | None = None
|
||||
self._error_bytes: bytes | None = None
|
||||
self._pipeline_error: bool = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Server is ready."""
|
||||
super().connection_made(transport)
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Handle connection is lost or closed."""
|
||||
super().connection_lost(exc)
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.hass.async_create_background_task(
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
if self._session_id is None:
|
||||
self._session_id = ulid_now()
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
if self.listening_tone_enabled:
|
||||
await self._play_listening_tone()
|
||||
|
||||
try:
|
||||
# Wait for speech before starting pipeline
|
||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||
audio_enhancer = MicroVadSpeexEnhancer(0, 0, True)
|
||||
chunk_buffer: deque[bytes] = deque(
|
||||
maxlen=self.buffered_chunks_before_speech,
|
||||
)
|
||||
speech_detected = await self._wait_for_speech(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
)
|
||||
if not speech_detected:
|
||||
_LOGGER.debug("No speech detected")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
self._tts_done.clear()
|
||||
|
||||
async def stt_stream():
|
||||
try:
|
||||
async for chunk in self._segment_audio(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if self.processing_tone_enabled:
|
||||
await self._play_processing_tone()
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Audio timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
async with asyncio.timeout(self.pipeline_timeout):
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.voip_device.voip_id
|
||||
),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=self.voip_device.device_id,
|
||||
tts_audio_output="wav",
|
||||
)
|
||||
|
||||
if self._pipeline_error:
|
||||
self._pipeline_error = False
|
||||
if self.error_tone_enabled:
|
||||
await self._play_error_tone()
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
async def _wait_for_speech(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: MutableSequence[bytes],
|
||||
):
|
||||
"""Buffer audio chunks until speech is detected.
|
||||
|
||||
Returns True if speech was detected, False otherwise.
|
||||
"""
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
chunk_buffer.append(chunk)
|
||||
|
||||
segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
)
|
||||
if segmenter.in_command:
|
||||
# Buffer until command starts
|
||||
if len(vad_buffer) > 0:
|
||||
chunk_buffer.append(vad_buffer.bytes())
|
||||
|
||||
return True
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
return False
|
||||
|
||||
async def _segment_audio(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: Sequence[bytes],
|
||||
) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks until voice command has finished."""
|
||||
# Buffered chunks first
|
||||
for buffered_chunk in chunk_buffer:
|
||||
yield buffered_chunk
|
||||
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
if not segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def _event_callback(self, event: PipelineEvent):
|
||||
if not event.data:
|
||||
return
|
||||
|
||||
if event.type == PipelineEventType.INTENT_END:
|
||||
# Capture conversation id
|
||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS
|
||||
self._pipeline_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_listening_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is listening."""
|
||||
if self._tone_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"tone.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes,
|
||||
silence_before=self.tone_delay,
|
||||
)
|
||||
|
||||
async def _play_processing_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
||||
if self._processing_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._processing_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"processing.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._processing_bytes)
|
||||
|
||||
async def _play_error_tone(self) -> None:
|
||||
"""Play a tone to indicate a pipeline error occurred."""
|
||||
if self._error_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._error_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"error.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._error_bytes)
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
||||
|
||||
|
||||
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
||||
"""Plays a pre-recorded message on a loop."""
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
|||
|
||||
AIR_QUALITY = "air_quality"
|
||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||
ASSIST_SATELLITE = "assist_satellite"
|
||||
BINARY_SENSOR = "binary_sensor"
|
||||
BUTTON = "button"
|
||||
CALENDAR = "calendar"
|
||||
|
|
|
@ -5,22 +5,28 @@ from __future__ import annotations
|
|||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
Future,
|
||||
Queue,
|
||||
Semaphore,
|
||||
Task,
|
||||
TimerHandle,
|
||||
gather,
|
||||
get_running_loop,
|
||||
timeout as async_timeout,
|
||||
)
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
||||
|
||||
_DataT = TypeVar("_DataT", default=Any)
|
||||
|
||||
|
||||
def create_eager_task[_T](
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
|
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
|
|||
"""Return a list of scheduled TimerHandles."""
|
||||
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
||||
return handles
|
||||
|
||||
|
||||
async def queue_to_iterable(
|
||||
queue: Queue[_DataT], timeout: float | None = None
|
||||
) -> AsyncIterable[_DataT]:
|
||||
"""Stream items from a queue until None with an optional timeout per item."""
|
||||
if timeout is None:
|
||||
while (item := await queue.get()) is not None:
|
||||
yield item
|
||||
else:
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
||||
while item is not None:
|
||||
yield item
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
|
3
tests/components/assist_satellite/__init__.py
Normal file
3
tests/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Tests for Assist Satellite."""
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
116
tests/components/assist_satellite/conftest.py
Normal file
116
tests/components/assist_satellite/conftest.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
"""Test helpers for Assist Satellite."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.tts.conftest import (
|
||||
mock_tts_cache_dir_fixture_autouse, # noqa: F401
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test_satellite"
|
||||
|
||||
|
||||
class MockAssistSatellite(AssistSatelliteEntity):
|
||||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
self.events = []
|
||||
self.announcements = []
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
self.events.append(event)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
"""Announce media on a device."""
|
||||
self.announcements.append((message, media_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity() -> MockAssistSatellite:
|
||||
"""Mock Assist Satellite Entity."""
|
||||
return MockAssistSatellite()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Mock config entry."""
|
||||
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
) -> None:
|
||||
"""Initialize components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test satellite platform via config entry."""
|
||||
async_add_entities([entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
149
tests/components/assist_satellite/test_entity.py
Normal file
149
tests/components/assist_satellite/test_entity.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
"""Test the Assist Satellite entity."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_update_pipeline,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteState
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
|
||||
async def test_entity_state(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test entity state represent events."""
|
||||
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
context = Context()
|
||||
audio_stream = object()
|
||||
|
||||
entity.async_set_context(context)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||
) as mock_start_pipeline:
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
assert mock_start_pipeline.called
|
||||
kwargs = mock_start_pipeline.call_args[1]
|
||||
assert kwargs["context"] is context
|
||||
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] == "wav"
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
)
|
||||
assert kwargs["start_stage"] == PipelineStage.STT
|
||||
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||
|
||||
for event_type, expected_state in (
|
||||
(PipelineEventType.RUN_START, STATE_UNKNOWN),
|
||||
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||
):
|
||||
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == expected_state, event_type
|
||||
|
||||
entity.tts_response_finished()
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("service_data", "expected_params"),
|
||||
[
|
||||
(
|
||||
{"message": "Hello"},
|
||||
("Hello", "https://www.home-assistant.io/resolved.mp3"),
|
||||
),
|
||||
(
|
||||
{
|
||||
"message": "Hello",
|
||||
"media_id": "http://example.com/bla.mp3",
|
||||
},
|
||||
("Hello", "http://example.com/bla.mp3"),
|
||||
),
|
||||
(
|
||||
{"media_id": "http://example.com/bla.mp3"},
|
||||
("", "http://example.com/bla.mp3"),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_announce(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
service_data: dict,
|
||||
expected_params: tuple[str, str],
|
||||
) -> None:
|
||||
"""Test announcing on a device."""
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
async_get_pipeline(hass),
|
||||
tts_engine="tts.mock_entity",
|
||||
tts_language="en",
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"assist_satellite",
|
||||
"announce",
|
||||
service_data,
|
||||
target={"entity_id": "assist_satellite.test_entity"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert entity.announcements[0] == expected_params
|
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
"""Test WebSocket API."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineStage
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
start_stage=PipelineStage.STT,
|
||||
wake_word_phrase="ok, nabu",
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
# Emulate wake word processing in Home Assistant
|
||||
start_stage=PipelineStage.WAKE_WORD,
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Only on-device wake words currently supported",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_requires_wake_word_phrase(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
start_stage=PipelineStage.STT,
|
||||
# We are not passing wake word phrase
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "No wake word phrase provided",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_require_admin(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
# Remove admin permission and verify we're not allowed
|
||||
hass_admin_user.groups = []
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "unauthorized",
|
||||
"message": "Unauthorized",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_invalid_satellite(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": "assist_satellite.invalid",
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Entity not found",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_twice(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Wake word interception already in progress",
|
||||
}
|
|
@ -20,7 +20,6 @@ from aioesphomeapi import (
|
|||
ReconnectLogic,
|
||||
UserService,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
)
|
||||
import pytest
|
||||
|
@ -34,11 +33,6 @@ from homeassistant.components.esphome.const import (
|
|||
DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.components.esphome.entry_data import RuntimeEntryData
|
||||
from homeassistant.components.esphome.voice_assistant import (
|
||||
VoiceAssistantAPIPipeline,
|
||||
VoiceAssistantUDPPipeline,
|
||||
)
|
||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -625,57 +619,3 @@ async def mock_esphome_device(
|
|||
)
|
||||
|
||||
return _mock_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voice_assistant_api_pipeline() -> VoiceAssistantAPIPipeline:
|
||||
"""Return the API Pipeline factory."""
|
||||
mock_pipeline = Mock(spec=VoiceAssistantAPIPipeline)
|
||||
|
||||
def mock_constructor(
|
||||
hass: HomeAssistant,
|
||||
entry_data: RuntimeEntryData,
|
||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||
handle_finished: Callable[[], None],
|
||||
api_client: APIClient,
|
||||
):
|
||||
"""Fake the constructor."""
|
||||
mock_pipeline.hass = hass
|
||||
mock_pipeline.entry_data = entry_data
|
||||
mock_pipeline.handle_event = handle_event
|
||||
mock_pipeline.handle_finished = handle_finished
|
||||
mock_pipeline.api_client = api_client
|
||||
return mock_pipeline
|
||||
|
||||
mock_pipeline.side_effect = mock_constructor
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantAPIPipeline",
|
||||
new=mock_pipeline,
|
||||
):
|
||||
yield mock_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voice_assistant_udp_pipeline() -> VoiceAssistantUDPPipeline:
|
||||
"""Return the API Pipeline factory."""
|
||||
mock_pipeline = Mock(spec=VoiceAssistantUDPPipeline)
|
||||
|
||||
def mock_constructor(
|
||||
hass: HomeAssistant,
|
||||
entry_data: RuntimeEntryData,
|
||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||
handle_finished: Callable[[], None],
|
||||
):
|
||||
"""Fake the constructor."""
|
||||
mock_pipeline.hass = hass
|
||||
mock_pipeline.entry_data = entry_data
|
||||
mock_pipeline.handle_event = handle_event
|
||||
mock_pipeline.handle_finished = handle_finished
|
||||
return mock_pipeline
|
||||
|
||||
mock_pipeline.side_effect = mock_constructor
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPPipeline",
|
||||
new=mock_pipeline,
|
||||
):
|
||||
yield mock_pipeline
|
||||
|
|
965
tests/components/esphome/test_assist_satellite.py
Normal file
965
tests/components/esphome/test_assist_satellite.py
Normal file
|
@ -0,0 +1,965 @@
|
|||
"""Test ESPHome voice assistant server."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import io
|
||||
import socket
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
EntityInfo,
|
||||
EntityState,
|
||||
UserService,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import assist_satellite
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.components.esphome import DOMAIN
|
||||
from homeassistant.components.esphome.assist_satellite import (
|
||||
EsphomeAssistSatellite,
|
||||
VoiceAssistantUDPServer,
|
||||
)
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er, intent as intent_helper
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
|
||||
def get_satellite_entity(
|
||||
hass: HomeAssistant, mac_address: str
|
||||
) -> EsphomeAssistSatellite | None:
|
||||
"""Get the satellite entity for a device."""
|
||||
ent_reg = er.async_get(hass)
|
||||
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.ASSIST_SATELLITE, DOMAIN, f"{mac_address}-assist_satellite"
|
||||
)
|
||||
if satellite_entity_id is None:
|
||||
return None
|
||||
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||
assist_satellite.DOMAIN
|
||||
]
|
||||
if (entity := component.get_entity(satellite_entity_id)) is not None:
|
||||
assert isinstance(entity, EsphomeAssistSatellite)
|
||||
return entity
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_wav() -> bytes:
|
||||
"""Return test WAV audio."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(b"test-wav")
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
||||
async def test_no_satellite_without_voice_assistant(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that an assist satellite entity is not created if a voice assistant is not present."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# No satellite entity should be created
|
||||
assert get_satellite_entity(hass, mock_device.device_info.mac_address) is None
|
||||
|
||||
|
||||
async def test_pipeline_api_audio(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test a complete pipeline run with API audio (over the TCP connection)."""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
# Block TTS streaming until we're ready.
|
||||
# This makes it easier to verify the order of pipeline events.
|
||||
stream_tts_audio_ready = asyncio.Event()
|
||||
original_stream_tts_audio = satellite._stream_tts_audio
|
||||
|
||||
async def _stream_tts_audio(*args, **kwargs):
|
||||
await stream_tts_audio_ready.wait()
|
||||
await original_stream_tts_audio(*args, **kwargs)
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
assert device_id == dev.id
|
||||
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
|
||||
chunks = [chunk async for chunk in stt_stream]
|
||||
|
||||
# Verify test API audio
|
||||
assert chunks == [b"test-mic"]
|
||||
|
||||
event_callback = kwargs["event_callback"]
|
||||
|
||||
# Test unknown event type
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type="unknown-event",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_event.assert_not_called()
|
||||
|
||||
# Test error event
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.ERROR,
|
||||
data={"code": "test-error-code", "message": "test-error-message"},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{"code": "test-error-code", "message": "test-error-message"},
|
||||
)
|
||||
|
||||
# Wake word
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.WAKE_WORD_START,
|
||||
data={
|
||||
"entity_id": "test-wake-word-entity-id",
|
||||
"metadata": {},
|
||||
"timeout": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START,
|
||||
{},
|
||||
)
|
||||
|
||||
# Test no wake word detected
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.WAKE_WORD_END, data={"wake_word_output": {}}
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{"code": "no_wake_word", "message": "No wake word detected"},
|
||||
)
|
||||
|
||||
# Correct wake word detection
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.WAKE_WORD_END,
|
||||
data={"wake_word_output": {"wake_word_phrase": "test-wake-word"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END,
|
||||
{},
|
||||
)
|
||||
|
||||
# STT
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_START,
|
||||
data={"engine": "test-stt-engine", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
|
||||
{},
|
||||
)
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "test-stt-text"}},
|
||||
)
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
|
||||
{"text": "test-stt-text"},
|
||||
)
|
||||
|
||||
# Intent
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test-intent-engine",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "test-intent-text",
|
||||
"conversation_id": conversation_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
|
||||
{},
|
||||
)
|
||||
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_END,
|
||||
data={"intent_output": {"conversation_id": conversation_id}},
|
||||
)
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
|
||||
{"conversation_id": conversation_id},
|
||||
)
|
||||
|
||||
# TTS
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test-stt-engine",
|
||||
"language": hass.config.language,
|
||||
"voice": "test-voice",
|
||||
"tts_input": "test-tts-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
|
||||
{"text": "test-tts-text"},
|
||||
)
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Should return mock_wav audio
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
)
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
||||
{"url": media_url},
|
||||
)
|
||||
|
||||
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END,
|
||||
{},
|
||||
)
|
||||
|
||||
# Allow TTS streaming to proceed
|
||||
stream_tts_audio_ready.set()
|
||||
|
||||
pipeline_finished = asyncio.Event()
|
||||
original_handle_pipeline_finished = satellite.handle_pipeline_finished
|
||||
|
||||
def handle_pipeline_finished():
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
tts_finished.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
# Should be cleared at pipeline start
|
||||
satellite._audio_queue.put_nowait(b"leftover-data")
|
||||
|
||||
# Should be cancelled at pipeline start
|
||||
mock_tts_streaming_task = Mock()
|
||||
satellite._tts_streaming_task = mock_tts_streaming_task
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await satellite.handle_pipeline_start(
|
||||
conversation_id=conversation_id,
|
||||
flags=VoiceAssistantCommandFlag.USE_WAKE_WORD,
|
||||
audio_settings=VoiceAssistantAudioSettings(),
|
||||
wake_word_phrase="",
|
||||
)
|
||||
mock_tts_streaming_task.cancel.assert_called_once()
|
||||
await satellite.handle_audio(b"test-mic")
|
||||
await satellite.handle_pipeline_stop()
|
||||
await pipeline_finished.wait()
|
||||
|
||||
await tts_finished.wait()
|
||||
|
||||
# Verify TTS streaming events.
|
||||
# These are definitely the last two events because we blocked TTS streaming
|
||||
# until after RUN_END above.
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||
{},
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||
{},
|
||||
)
|
||||
|
||||
# Verify TTS WAV audio chunk came through
|
||||
mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_pipeline_udp_audio(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test a complete pipeline run with legacy UDP audio.
|
||||
|
||||
This test is not as comprehensive as test_pipeline_api_audio since we're
|
||||
mainly focused on the UDP server.
|
||||
"""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
mic_audio_event = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
|
||||
chunks = []
|
||||
async for chunk in stt_stream:
|
||||
chunks.append(chunk)
|
||||
mic_audio_event.set()
|
||||
|
||||
# Verify test UDP audio
|
||||
assert chunks == [b"test-mic"]
|
||||
|
||||
event_callback = kwargs["event_callback"]
|
||||
|
||||
# STT
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_START,
|
||||
data={"engine": "test-stt-engine", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "test-stt-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Intent
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test-intent-engine",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "test-intent-text",
|
||||
"conversation_id": conversation_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_END,
|
||||
data={"intent_output": {"conversation_id": conversation_id}},
|
||||
)
|
||||
)
|
||||
|
||||
# TTS
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test-stt-engine",
|
||||
"language": hass.config.language,
|
||||
"voice": "test-voice",
|
||||
"tts_input": "test-tts-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should return mock_wav audio
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||
|
||||
pipeline_finished = asyncio.Event()
|
||||
original_handle_pipeline_finished = satellite.handle_pipeline_finished
|
||||
|
||||
def handle_pipeline_finished():
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
tts_finished.set()
|
||||
|
||||
class TestProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self) -> None:
|
||||
self.transport = None
|
||||
self.data_received: list[bytes] = []
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def datagram_received(self, data: bytes, addr):
|
||||
self.data_received.append(data)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
port = await satellite.handle_pipeline_start(
|
||||
conversation_id=conversation_id,
|
||||
flags=VoiceAssistantCommandFlag(0), # stt
|
||||
audio_settings=VoiceAssistantAudioSettings(),
|
||||
wake_word_phrase="",
|
||||
)
|
||||
assert (port is not None) and (port > 0)
|
||||
|
||||
(
|
||||
transport,
|
||||
protocol,
|
||||
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||
TestProtocol, remote_addr=("127.0.0.1", port)
|
||||
)
|
||||
assert isinstance(protocol, TestProtocol)
|
||||
|
||||
# Send audio over UDP
|
||||
transport.sendto(b"test-mic")
|
||||
|
||||
# Wait for audio chunk to be delivered
|
||||
await mic_audio_event.wait()
|
||||
|
||||
await satellite.handle_pipeline_stop()
|
||||
await pipeline_finished.wait()
|
||||
|
||||
await tts_finished.wait()
|
||||
|
||||
# Verify TTS audio (from UDP)
|
||||
assert protocol.data_received == [b"test-wav"]
|
||||
|
||||
# Check that UDP server was stopped
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(("", port)) # will fail if UDP server is still running
|
||||
sock.close()
|
||||
|
||||
|
||||
async def test_udp_errors() -> None:
|
||||
"""Test UDP protocol error conditions."""
|
||||
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
protocol = VoiceAssistantUDPServer(audio_queue)
|
||||
|
||||
protocol.datagram_received(b"test", ("", 0))
|
||||
assert audio_queue.qsize() == 1
|
||||
assert (await audio_queue.get()) == b"test"
|
||||
|
||||
# None will stop the pipeline
|
||||
protocol.error_received(RuntimeError())
|
||||
assert audio_queue.qsize() == 1
|
||||
assert (await audio_queue.get()) is None
|
||||
|
||||
# No transport
|
||||
assert protocol.transport is None
|
||||
protocol.send_audio_bytes(b"test")
|
||||
|
||||
# No remote address
|
||||
protocol.transport = Mock()
|
||||
protocol.remote_addr = None
|
||||
protocol.send_audio_bytes(b"test")
|
||||
protocol.transport.sendto.assert_not_called()
|
||||
|
||||
|
||||
async def test_timer_events(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that injecting timer events results in the correct api client calls."""
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.TIMERS
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
|
||||
total_seconds = (1 * 60 * 60) + (2 * 60) + 3
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_START_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"hours": {"value": 1},
|
||||
"minutes": {"value": 2},
|
||||
"seconds": {"value": 3},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
|
||||
ANY,
|
||||
"test timer",
|
||||
total_seconds,
|
||||
total_seconds,
|
||||
True,
|
||||
)
|
||||
|
||||
# Increase timer beyond original time and check total_seconds has increased
|
||||
mock_client.send_voice_assistant_timer_event.reset_mock()
|
||||
|
||||
total_seconds += 5 * 60
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_INCREASE_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"minutes": {"value": 5},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
|
||||
ANY,
|
||||
"test timer",
|
||||
total_seconds,
|
||||
ANY,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
async def test_unknown_timer_event(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that unknown (new) timer event types do not result in api calls."""
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.TIMERS
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert mock_device.entry.unique_id is not None
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
assert dev is not None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass",
|
||||
side_effect=KeyError,
|
||||
):
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_START_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"hours": {"value": 1},
|
||||
"minutes": {"value": 2},
|
||||
"seconds": {"value": 3},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_not_called()
|
||||
|
||||
|
||||
async def test_streaming_tts_errors(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test error conditions for _stream_tts_audio function."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
# Should not stream if not running
|
||||
satellite._is_running = False
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
satellite._is_running = True
|
||||
|
||||
# Should only stream WAV
|
||||
async def get_mp3(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("mp3", b"")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
|
||||
):
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
|
||||
# Needs to be the correct sample rate, etc.
|
||||
async def get_bad_wav(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(48000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(b"test-wav")
|
||||
|
||||
return ("wav", wav_io.getvalue())
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
|
||||
):
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
|
||||
# Check that TTS_STREAM_* events still get sent after cancel
|
||||
media_fetched = asyncio.Event()
|
||||
|
||||
async def get_slow_wav(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
media_fetched.set()
|
||||
await asyncio.sleep(1)
|
||||
return ("wav", mock_wav)
|
||||
|
||||
mock_client.send_voice_assistant_event.reset_mock()
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
|
||||
):
|
||||
task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
|
||||
async with asyncio.timeout(1):
|
||||
# Wait for media to be fetched
|
||||
await media_fetched.wait()
|
||||
|
||||
# Cancel task
|
||||
task.cancel()
|
||||
await task
|
||||
|
||||
# No audio should have gone out
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
||||
|
||||
# The TTS_STREAM_* events should have gone out
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||
{},
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
async def test_announce_supported_features(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that the announce supported feature is set by flags."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
assert not (satellite.supported_features & AssistSatelliteEntityFeature.ANNOUNCE)
|
||||
|
||||
|
||||
async def test_announce_message(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test announcement with message."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def wait_voice_assistant_announce(media_id: str, text: str):
|
||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||
assert text == "test-text"
|
||||
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
mock_client,
|
||||
"wait_voice_assistant_announce",
|
||||
new=wait_voice_assistant_announce,
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
"announce",
|
||||
{"entity_id": satellite.entity_id, "message": "test-text"},
|
||||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_announce_media_id(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test announcement with media id."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def wait_voice_assistant_announce(media_id: str, text: str):
|
||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
mock_client,
|
||||
"wait_voice_assistant_announce",
|
||||
new=wait_voice_assistant_announce,
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
"announce",
|
||||
{
|
||||
"entity_id": satellite.entity_id,
|
||||
"media_id": "https://www.home-assistant.io/resolved.mp3",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from unittest.mock import AsyncMock, call, patch
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
|
@ -17,7 +17,6 @@ from aioesphomeapi import (
|
|||
UserService,
|
||||
UserServiceArg,
|
||||
UserServiceArgType,
|
||||
VoiceAssistantFeature,
|
||||
)
|
||||
import pytest
|
||||
|
||||
|
@ -29,10 +28,6 @@ from homeassistant.components.esphome.const import (
|
|||
DOMAIN,
|
||||
STABLE_BLE_VERSION_STR,
|
||||
)
|
||||
from homeassistant.components.esphome.voice_assistant import (
|
||||
VoiceAssistantAPIPipeline,
|
||||
VoiceAssistantUDPPipeline,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_PASSWORD,
|
||||
|
@ -44,7 +39,7 @@ from homeassistant.data_entry_flow import FlowResultType
|
|||
from homeassistant.helpers import device_registry as dr, issue_registry as ir
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .conftest import _ONE_SECOND, MockESPHomeDevice
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
from tests.common import MockConfigEntry, async_capture_events, async_mock_service
|
||||
|
||||
|
@ -1214,102 +1209,3 @@ async def test_entry_missing_unique_id(
|
|||
await mock_esphome_device(mock_client=mock_client, mock_storage=True)
|
||||
await hass.async_block_till_done()
|
||||
assert entry.unique_id == "11:22:33:44:55:aa"
|
||||
|
||||
|
||||
async def test_manager_voice_assistant_handlers_api(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
mock_voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the handlers are correctly executed in manager.py."""
|
||||
|
||||
device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
},
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.manager.VoiceAssistantAPIPipeline",
|
||||
new=mock_voice_assistant_api_pipeline,
|
||||
),
|
||||
):
|
||||
port: int | None = await device.mock_voice_assistant_handle_start(
|
||||
"", 0, None, None
|
||||
)
|
||||
|
||||
assert port == 0
|
||||
|
||||
port: int | None = await device.mock_voice_assistant_handle_start(
|
||||
"", 0, None, None
|
||||
)
|
||||
|
||||
assert "Previous Voice assistant pipeline was not stopped" in caplog.text
|
||||
|
||||
await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND))
|
||||
|
||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_called_with(
|
||||
bytes(_ONE_SECOND)
|
||||
)
|
||||
|
||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.reset_mock()
|
||||
|
||||
await device.mock_voice_assistant_handle_stop()
|
||||
mock_voice_assistant_api_pipeline.handle_finished()
|
||||
|
||||
await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND))
|
||||
|
||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_not_called()
|
||||
|
||||
|
||||
async def test_manager_voice_assistant_handlers_udp(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
mock_voice_assistant_udp_pipeline: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test the handlers are correctly executed in manager.py."""
|
||||
|
||||
device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
},
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.manager.VoiceAssistantUDPPipeline",
|
||||
new=mock_voice_assistant_udp_pipeline,
|
||||
),
|
||||
):
|
||||
await device.mock_voice_assistant_handle_start("", 0, None, None)
|
||||
|
||||
mock_voice_assistant_udp_pipeline.run_pipeline.assert_called()
|
||||
|
||||
await device.mock_voice_assistant_handle_stop()
|
||||
mock_voice_assistant_udp_pipeline.handle_finished()
|
||||
|
||||
mock_voice_assistant_udp_pipeline.stop.assert_called()
|
||||
mock_voice_assistant_udp_pipeline.close.assert_called()
|
||||
|
|
|
@ -1,964 +0,0 @@
|
|||
"""Test ESPHome voice assistant server."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import io
|
||||
import socket
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
EntityInfo,
|
||||
EntityState,
|
||||
UserService,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.error import (
|
||||
PipelineNotFound,
|
||||
WakeWordDetectionAborted,
|
||||
WakeWordDetectionError,
|
||||
)
|
||||
from homeassistant.components.esphome import DomainData
|
||||
from homeassistant.components.esphome.voice_assistant import (
|
||||
VoiceAssistantAPIPipeline,
|
||||
VoiceAssistantUDPPipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent as intent_helper
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
|
||||
from .conftest import _ONE_SECOND, MockESPHomeDevice
|
||||
|
||||
_TEST_INPUT_TEXT = "This is an input test"
|
||||
_TEST_OUTPUT_TEXT = "This is an output test"
|
||||
_TEST_OUTPUT_URL = "output.mp3"
|
||||
_TEST_MEDIA_ID = "12345"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_pipeline(
|
||||
hass: HomeAssistant,
|
||||
) -> VoiceAssistantUDPPipeline:
|
||||
"""Return the UDP pipeline factory."""
|
||||
|
||||
def _voice_assistant_udp_server(entry):
|
||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
||||
|
||||
server: VoiceAssistantUDPPipeline = None
|
||||
|
||||
def handle_finished():
|
||||
nonlocal server
|
||||
assert server is not None
|
||||
server.close()
|
||||
|
||||
server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished)
|
||||
return server # noqa: RET504
|
||||
|
||||
return _voice_assistant_udp_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_api_pipeline(
|
||||
hass: HomeAssistant,
|
||||
mock_client,
|
||||
mock_voice_assistant_api_entry,
|
||||
) -> VoiceAssistantAPIPipeline:
|
||||
"""Return the API Pipeline factory."""
|
||||
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry)
|
||||
return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_pipeline_v1(
|
||||
voice_assistant_udp_pipeline,
|
||||
mock_voice_assistant_v1_entry,
|
||||
) -> VoiceAssistantUDPPipeline:
|
||||
"""Return the UDP pipeline."""
|
||||
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_pipeline_v2(
|
||||
voice_assistant_udp_pipeline,
|
||||
mock_voice_assistant_v2_entry,
|
||||
) -> VoiceAssistantUDPPipeline:
|
||||
"""Return the UDP pipeline."""
|
||||
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_wav() -> bytes:
|
||||
"""Return one second of empty WAV audio."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
||||
async def test_pipeline_events(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test that the pipeline function is called."""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
assert device_id == "mock-device-id"
|
||||
|
||||
event_callback = kwargs["event_callback"]
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.WAKE_WORD_END,
|
||||
data={"wake_word_output": {}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake events
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_START,
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_START,
|
||||
data={"tts_input": _TEST_OUTPUT_TEXT},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": _TEST_OUTPUT_URL}},
|
||||
)
|
||||
)
|
||||
|
||||
def handle_event(
|
||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||
assert data is not None
|
||||
assert data["text"] == _TEST_INPUT_TEXT
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||
assert data is not None
|
||||
assert data["text"] == _TEST_OUTPUT_TEXT
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||
assert data is not None
|
||||
assert data["url"] == _TEST_OUTPUT_URL
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||
assert data is None
|
||||
|
||||
voice_assistant_udp_pipeline_v1.handle_event = handle_event
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
voice_assistant_udp_pipeline_v1.transport = Mock()
|
||||
|
||||
await voice_assistant_udp_pipeline_v1.run_pipeline(
|
||||
device_id="mock-device-id", conversation_id=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_udp_server(
|
||||
unused_udp_port_factory: Callable[[], int],
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test the UDP server runs and queues incoming data."""
|
||||
port_to_use = unused_udp_port_factory()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
||||
):
|
||||
port = await voice_assistant_udp_pipeline_v1.start_server()
|
||||
assert port == port_to_use
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
||||
sock.sendto(b"test", ("127.0.0.1", port))
|
||||
|
||||
# Give the socket some time to send/receive the data
|
||||
async with asyncio.timeout(1):
|
||||
while voice_assistant_udp_pipeline_v1.queue.qsize() == 0:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
||||
|
||||
voice_assistant_udp_pipeline_v1.stop()
|
||||
voice_assistant_udp_pipeline_v1.close()
|
||||
|
||||
assert voice_assistant_udp_pipeline_v1.transport.is_closing()
|
||||
|
||||
|
||||
async def test_udp_server_queue(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test the UDP server queues incoming data."""
|
||||
|
||||
voice_assistant_udp_pipeline_v1.started = True
|
||||
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
||||
|
||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
||||
|
||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||
|
||||
async for data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
||||
assert data == bytes(1024)
|
||||
break
|
||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed
|
||||
|
||||
voice_assistant_udp_pipeline_v1.stop()
|
||||
assert (
|
||||
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||
) # An empty message added by stop
|
||||
|
||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||
assert (
|
||||
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||
) # No new messages added after stop
|
||||
|
||||
voice_assistant_udp_pipeline_v1.close()
|
||||
|
||||
# Stopping the UDP server should cause _iterate_packets to break out
|
||||
# immediately without yielding any data.
|
||||
has_data = False
|
||||
async for _data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
||||
has_data = True
|
||||
|
||||
assert not has_data, "Server was stopped"
|
||||
|
||||
|
||||
async def test_api_pipeline_queue(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the API pipeline queues incoming data."""
|
||||
|
||||
voice_assistant_api_pipeline.started = True
|
||||
|
||||
assert voice_assistant_api_pipeline.queue.qsize() == 0
|
||||
|
||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||
assert voice_assistant_api_pipeline.queue.qsize() == 1
|
||||
|
||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||
assert voice_assistant_api_pipeline.queue.qsize() == 2
|
||||
|
||||
async for data in voice_assistant_api_pipeline._iterate_packets():
|
||||
assert data == bytes(1024)
|
||||
break
|
||||
assert voice_assistant_api_pipeline.queue.qsize() == 1 # One message removed
|
||||
|
||||
voice_assistant_api_pipeline.stop()
|
||||
assert (
|
||||
voice_assistant_api_pipeline.queue.qsize() == 2
|
||||
) # An empty message added by stop
|
||||
|
||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||
assert (
|
||||
voice_assistant_api_pipeline.queue.qsize() == 2
|
||||
) # No new messages added after stop
|
||||
|
||||
# Stopping the API Pipeline should cause _iterate_packets to break out
|
||||
# immediately without yielding any data.
|
||||
has_data = False
|
||||
async for _data in voice_assistant_api_pipeline._iterate_packets():
|
||||
has_data = True
|
||||
|
||||
assert not has_data, "Pipeline was stopped"
|
||||
|
||||
|
||||
async def test_error_calls_handle_finished(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test that the handle_finished callback is called when an error occurs."""
|
||||
voice_assistant_udp_pipeline_v1.handle_finished = Mock()
|
||||
|
||||
voice_assistant_udp_pipeline_v1.error_received(Exception())
|
||||
|
||||
voice_assistant_udp_pipeline_v1.handle_finished.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_udp_server_multiple(
|
||||
unused_udp_port_factory: Callable[[], int],
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test that the UDP server raises an error if started twice."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
):
|
||||
await voice_assistant_udp_pipeline_v1.start_server()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
await voice_assistant_udp_pipeline_v1.start_server()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_udp_server_after_stopped(
|
||||
unused_udp_port_factory: Callable[[], int],
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test that the UDP server raises an error if started after stopped."""
|
||||
voice_assistant_udp_pipeline_v1.close()
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
await voice_assistant_udp_pipeline_v1.start_server()
|
||||
|
||||
|
||||
async def test_events_converted_correctly(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the pipeline events produce the correct data to send to the device."""
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts",
|
||||
):
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_START,
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "text"}},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": "text"}
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_START,
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_END,
|
||||
data={
|
||||
"intent_output": {
|
||||
"conversation_id": "conversation-id",
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
|
||||
{"conversation_id": "conversation-id"},
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_START,
|
||||
data={"tts_input": "text"},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": "text"}
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": "url", "media_id": "media-id"}},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END, {"url": "url"}
|
||||
)
|
||||
|
||||
|
||||
async def test_unknown_event_type(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the API pipeline does not call handle_event for unknown events."""
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type="unknown-event",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
assert not voice_assistant_api_pipeline.handle_event.called
|
||||
|
||||
|
||||
async def test_error_event_type(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the API pipeline calls event handler with error."""
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.ERROR,
|
||||
data={"code": "code", "message": "message"},
|
||||
)
|
||||
)
|
||||
|
||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{"code": "code", "message": "message"},
|
||||
)
|
||||
|
||||
|
||||
async def test_send_tts_not_called(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test the UDP server with a v1 device does not call _send_tts."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||
) as mock_send_tts:
|
||||
voice_assistant_udp_pipeline_v1._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_send_tts_called_udp(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||
) -> None:
|
||||
"""Test the UDP server with a v2 device calls _send_tts."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||
) as mock_send_tts:
|
||||
voice_assistant_udp_pipeline_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
||||
|
||||
|
||||
async def test_send_tts_called_api(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the API pipeline calls _send_tts."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||
) as mock_send_tts:
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
||||
|
||||
|
||||
async def test_send_tts_not_called_when_empty(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test the pipelines do not call _send_tts when the output is empty."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||
) as mock_send_tts:
|
||||
voice_assistant_udp_pipeline_v1._event_callback(
|
||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
voice_assistant_udp_pipeline_v2._event_callback(
|
||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_send_tts_udp(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", mock_wav),
|
||||
):
|
||||
voice_assistant_udp_pipeline_v2.started = True
|
||||
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
with patch.object(
|
||||
voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False
|
||||
):
|
||||
voice_assistant_udp_pipeline_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": _TEST_MEDIA_ID,
|
||||
"url": _TEST_OUTPUT_URL,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||
|
||||
voice_assistant_udp_pipeline_v2.transport.sendto.assert_called()
|
||||
|
||||
|
||||
async def test_send_tts_api(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test the API pipeline calls cli.send_voice_assistant_audio to transmit audio data to device."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", mock_wav),
|
||||
):
|
||||
voice_assistant_api_pipeline.started = True
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": _TEST_MEDIA_ID,
|
||||
"url": _TEST_OUTPUT_URL,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await voice_assistant_api_pipeline._tts_done.wait()
|
||||
|
||||
mock_client.send_voice_assistant_audio.assert_called()
|
||||
|
||||
|
||||
async def test_send_tts_wrong_sample_rate(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that only 16000Hz audio will be streamed."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
wav_bytes = wav_io.getvalue()
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", wav_bytes),
|
||||
):
|
||||
voice_assistant_api_pipeline.started = True
|
||||
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert voice_assistant_api_pipeline._tts_task is not None
|
||||
with pytest.raises(ValueError):
|
||||
await voice_assistant_api_pipeline._tts_task
|
||||
|
||||
|
||||
async def test_send_tts_wrong_format(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that only WAV audio will be streamed."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("raw", bytes(1024)),
|
||||
),
|
||||
):
|
||||
voice_assistant_api_pipeline.started = True
|
||||
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_api_pipeline._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert voice_assistant_api_pipeline._tts_task is not None
|
||||
with pytest.raises(ValueError):
|
||||
await voice_assistant_api_pipeline._tts_task
|
||||
|
||||
|
||||
async def test_send_tts_not_started(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test the UDP server does not call sendto when not started."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", mock_wav),
|
||||
):
|
||||
voice_assistant_udp_pipeline_v2.started = False
|
||||
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_pipeline_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||
|
||||
voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called()
|
||||
|
||||
|
||||
async def test_send_tts_transport_none(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||
mock_wav: bytes,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test the UDP server does not call sendto when transport is None."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", mock_wav),
|
||||
):
|
||||
voice_assistant_udp_pipeline_v2.started = True
|
||||
voice_assistant_udp_pipeline_v2.transport = None
|
||||
|
||||
voice_assistant_udp_pipeline_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||
|
||||
assert "No transport to send audio to" in caplog.text
|
||||
|
||||
|
||||
async def test_wake_word(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that the pipeline is set to start with Wake word."""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
|
||||
assert start_stage == PipelineStage.WAKE_WORD
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch("asyncio.Event.wait"), # TTS wait event
|
||||
):
|
||||
await voice_assistant_api_pipeline.run_pipeline(
|
||||
device_id="mock-device-id",
|
||||
conversation_id=None,
|
||||
flags=2,
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_exception(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that the pipeline is set to start with Wake word."""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
|
||||
def handle_event(
|
||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||
assert data is not None
|
||||
assert data["code"] == "pipeline-not-found"
|
||||
assert data["message"] == "Pipeline not found"
|
||||
|
||||
voice_assistant_api_pipeline.handle_event = handle_event
|
||||
|
||||
await voice_assistant_api_pipeline.run_pipeline(
|
||||
device_id="mock-device-id",
|
||||
conversation_id=None,
|
||||
flags=2,
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_abort_exception(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that the pipeline is set to start with Wake word."""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
raise WakeWordDetectionAborted
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event,
|
||||
):
|
||||
await voice_assistant_api_pipeline.run_pipeline(
|
||||
device_id="mock-device-id",
|
||||
conversation_id=None,
|
||||
flags=2,
|
||||
)
|
||||
|
||||
mock_handle_event.assert_not_called()
|
||||
|
||||
|
||||
async def test_timer_events(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that injecting timer events results in the correct api client calls."""
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.TIMERS
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
|
||||
total_seconds = (1 * 60 * 60) + (2 * 60) + 3
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_START_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"hours": {"value": 1},
|
||||
"minutes": {"value": 2},
|
||||
"seconds": {"value": 3},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
|
||||
ANY,
|
||||
"test timer",
|
||||
total_seconds,
|
||||
total_seconds,
|
||||
True,
|
||||
)
|
||||
|
||||
# Increase timer beyond original time and check total_seconds has increased
|
||||
mock_client.send_voice_assistant_timer_event.reset_mock()
|
||||
|
||||
total_seconds += 5 * 60
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_INCREASE_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"minutes": {"value": 5},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
|
||||
ANY,
|
||||
"test timer",
|
||||
total_seconds,
|
||||
ANY,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
async def test_unknown_timer_event(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that unknown (new) timer event types do not result in api calls."""
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.TIMERS
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant._TIMER_EVENT_TYPES.from_hass",
|
||||
side_effect=KeyError,
|
||||
):
|
||||
await intent_helper.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent_helper.INTENT_START_TIMER,
|
||||
{
|
||||
"name": {"value": "test timer"},
|
||||
"hours": {"value": 1},
|
||||
"minutes": {"value": 2},
|
||||
"seconds": {"value": 3},
|
||||
},
|
||||
device_id=dev.id,
|
||||
)
|
||||
|
||||
mock_client.send_voice_assistant_timer_event.assert_not_called()
|
||||
|
||||
|
||||
async def test_invalid_pipeline_id(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||
) -> None:
|
||||
"""Test that the pipeline is set to start with Wake word."""
|
||||
|
||||
invalid_pipeline_id = "invalid-pipeline-id"
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
raise PipelineNotFound(
|
||||
"pipeline_not_found", f"Pipeline {invalid_pipeline_id} not found"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
|
||||
def handle_event(
|
||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||
assert data is not None
|
||||
assert data["code"] == "pipeline_not_found"
|
||||
assert data["message"] == f"Pipeline {invalid_pipeline_id} not found"
|
||||
|
||||
voice_assistant_api_pipeline.handle_event = handle_event
|
||||
|
||||
await voice_assistant_api_pipeline.run_pipeline(
|
||||
device_id="mock-device-id",
|
||||
conversation_id=None,
|
||||
flags=2,
|
||||
)
|
|
@ -14,6 +14,9 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.tts.conftest import (
|
||||
mock_tts_cache_dir_fixture_autouse, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
File diff suppressed because one or more lines are too long
|
@ -3,15 +3,26 @@
|
|||
import asyncio
|
||||
import io
|
||||
from pathlib import Path
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from voip_utils import CallInfo
|
||||
|
||||
from homeassistant.components import assist_pipeline, voip
|
||||
from homeassistant.components.voip.devices import VoIPDevice
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, voip
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
|
|||
return wav_io.getvalue()
|
||||
|
||||
|
||||
def async_get_satellite_entity(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> AssistSatelliteEntity | None:
|
||||
"""Get Assist satellite entity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
|
||||
)
|
||||
if satellite_entity_id is None:
|
||||
return None
|
||||
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||
assist_satellite.DOMAIN
|
||||
]
|
||||
return component.get_entity(satellite_entity_id)
|
||||
|
||||
|
||||
async def test_is_valid_call(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that a call is now allowed from an unknown device."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol = HassVoipDatagramProtocol(hass, voip_devices)
|
||||
assert not protocol.is_valid_call(call_info)
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
allowed_call_entity_id = ent_reg.async_get_entity_id(
|
||||
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
|
||||
)
|
||||
assert allowed_call_entity_id is not None
|
||||
state = hass.states.get(allowed_call_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
|
||||
# Allow calls
|
||||
hass.states.async_set(allowed_call_entity_id, STATE_ON)
|
||||
assert protocol.is_valid_call(call_info)
|
||||
|
||||
|
||||
async def test_calls_not_allowed(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when calls aren't allowed."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
# Test the playback
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
def send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be problem.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
protocol.transport = Mock()
|
||||
protocol.loop_delay = 0
|
||||
with patch.object(protocol, "send_audio", send_audio):
|
||||
protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
||||
|
||||
async def test_pipeline_not_found(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when a pipeline isn't found."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
|
||||
):
|
||||
protocol: PreRecordMessageProtocol = make_protocol(
|
||||
hass, voip_devices, call_info
|
||||
)
|
||||
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
|
||||
async def test_satellite_prepared(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that satellite is prepared for a call."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test",
|
||||
stt_engine="test",
|
||||
stt_language="en",
|
||||
tts_engine="test",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
protocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert protocol == satellite
|
||||
|
||||
|
||||
async def test_pipeline(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that pipeline function is called from RTP protocol."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
voip_user_id = satellite.config_entry.data["user"]
|
||||
assert voip_user_id
|
||||
|
||||
return 0
|
||||
# Satellite is muted until a call begins
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
# Used to test that audio queue is cleared before pipeline starts
|
||||
bad_chunk = bytes([1, 2, 3, 4])
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
|
||||
):
|
||||
assert context.user_id == voip_user_id
|
||||
assert device_id == voip_device.device_id
|
||||
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
assert _chunk != bad_chunk
|
||||
assert chunk != bad_chunk
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Test empty data
|
||||
event_callback(
|
||||
|
@ -71,6 +229,38 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_START,
|
||||
data={"engine": "test", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "fake-text",
|
||||
"conversation_id": None,
|
||||
"device_id": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -83,6 +273,21 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# Fake tts result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"voice": "test",
|
||||
"tts_input": "fake-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Proceed with media output
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -91,6 +296,18 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.RUN_END
|
||||
)
|
||||
)
|
||||
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
done.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
|
@ -100,102 +317,56 @@ async def test_pipeline(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite._tones = Tones(0)
|
||||
satellite.transport = Mock()
|
||||
|
||||
satellite.connection_made(satellite.transport)
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
# Ensure audio queue is cleared before pipeline starts
|
||||
rtp_protocol._audio_queue.put_nowait(bad_chunk)
|
||||
satellite._audio_queue.put_nowait(bad_chunk)
|
||||
|
||||
def send_audio(*args, **kwargs):
|
||||
# Test finished successfully
|
||||
done.set()
|
||||
# Don't send audio
|
||||
pass
|
||||
|
||||
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
||||
satellite.send_audio = Mock(side_effect=send_audio)
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes aggressive VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
"""Test timeout during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
pipeline_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
# Finished speaking
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
async def test_stt_stream_timeout(
|
||||
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
|
||||
) -> None:
|
||||
"""Test timeout in STT stream during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
|
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
pass
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
audio_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
satellite._tones = Tones(0)
|
||||
satellite._audio_chunk_timeout = 0.001
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
satellite.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
|
||||
async def test_tts_timeout(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will time out based on its length."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -278,15 +448,7 @@ async def test_tts_timeout(
|
|||
|
||||
tone_bytes = bytes([1, 2, 3, 4])
|
||||
|
||||
def send_audio(audio_bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
||||
# Block here to force a timeout in _send_tts
|
||||
time.sleep(2)
|
||||
|
||||
async def async_send_audio(audio_bytes, **kwargs):
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
@ -303,37 +465,22 @@ async def test_tts_timeout(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
tts_extra_timeout=0.001,
|
||||
listening_tone_enabled=True,
|
||||
processing_tone_enabled=True,
|
||||
error_tone_enabled=True,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
||||
)
|
||||
rtp_protocol._tone_bytes = tone_bytes
|
||||
rtp_protocol._processing_bytes = tone_bytes
|
||||
rtp_protocol._error_bytes = tone_bytes
|
||||
rtp_protocol.transport = Mock()
|
||||
rtp_protocol.send_audio = Mock()
|
||||
satellite._tts_extra_timeout = 0.001
|
||||
for tone in Tones:
|
||||
satellite._tone_bytes[tone] = tone_bytes
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
satellite.transport = Mock()
|
||||
satellite.send_audio = Mock()
|
||||
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -342,17 +489,17 @@ async def test_tts_timeout(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -361,26 +508,34 @@ async def test_tts_timeout(
|
|||
|
||||
async def test_tts_wrong_extension(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
|
|||
|
||||
async def test_tts_wrong_wav_format(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio with a specific format."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
async def test_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will not stream when output is empty."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -605,37 +754,78 @@ async def test_empty_tts_output(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to finish
|
||||
async with asyncio.timeout(1):
|
||||
await rtp_protocol._tts_done.wait()
|
||||
await satellite._tts_done.wait()
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_pipeline_error(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pipeline error causes the error tone to be played."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
# Fake error
|
||||
event_callback = kwargs["event_callback"]
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.ERROR,
|
||||
data={"code": "error-code", "message": "error message"},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be error.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
satellite._tones = Tones.ERROR
|
||||
satellite.transport = Mock()
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for error tone to be played
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
|
|||
timer_handle.cancel()
|
||||
timer_handle2.cancel()
|
||||
timer_handle3.cancel()
|
||||
|
||||
|
||||
async def test_queue_to_iterable() -> None:
|
||||
"""Test queue_to_iterable."""
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
expected_items = list(range(10))
|
||||
|
||||
for i in expected_items:
|
||||
await queue.put(i)
|
||||
|
||||
# Will terminate the stream
|
||||
await queue.put(None)
|
||||
|
||||
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
|
||||
|
||||
assert expected_items == actual_items
|
||||
|
||||
# Check timeout
|
||||
assert queue.empty()
|
||||
|
||||
# Time out on first item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check timeout on second item
|
||||
assert queue.empty()
|
||||
await queue.put(12345)
|
||||
|
||||
# Time out on second item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
if item != 12345:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert queue.empty()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue