Add async_announce
This commit is contained in:
parent
d375bfaefe
commit
f0c49b3995
7 changed files with 383 additions and 172 deletions
|
@ -10,18 +10,14 @@ from homeassistant.helpers.typing import ConfigType
|
|||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||
from .models import (
|
||||
AssistSatelliteEntityFeature,
|
||||
AssistSatelliteState,
|
||||
PipelineRunConfig,
|
||||
)
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteState",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
"PipelineRunConfig",
|
||||
"AssistSatelliteEntityFeature",
|
||||
]
|
||||
|
||||
|
@ -35,6 +31,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline import (
|
|||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
|
@ -28,11 +29,7 @@ from homeassistant.helpers.entity import EntityDescription
|
|||
from homeassistant.util import ulid
|
||||
|
||||
from .errors import SatelliteBusyError
|
||||
from .models import (
|
||||
AssistSatelliteEntityFeature,
|
||||
AssistSatelliteState,
|
||||
PipelineRunConfig,
|
||||
)
|
||||
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,6 +51,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_is_announcing: bool = False
|
||||
_tts_finished_event: asyncio.Event | None = None
|
||||
_wake_word_future: asyncio.Future[str | None] | None = None
|
||||
|
||||
|
@ -61,14 +59,10 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
"""Run when entity about to be added to hass."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
async def async_trigger_pipeline_on_satellite(
|
||||
self, run_config: PipelineRunConfig
|
||||
) -> None:
|
||||
"""Run a pipeline on the satellite with the configuration.
|
||||
|
||||
Requires TRIGGER_PIPELINE supported feature.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@property
|
||||
def is_announcing(self) -> bool:
|
||||
"""Returns true if currently announcing."""
|
||||
return self._is_announcing
|
||||
|
||||
async def async_announce(
|
||||
self,
|
||||
|
@ -76,72 +70,78 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
announce_media_id: str | None = None,
|
||||
pipeline_entity_id: str | None = None,
|
||||
) -> None:
|
||||
"""Play an announcement on the satellite."""
|
||||
if self._tts_finished_event is not None:
|
||||
raise SatelliteBusyError()
|
||||
"""Play an announcement on the satellite.
|
||||
|
||||
if not announce_media_id:
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
If announce_media_id is not provided, announce_text is synthesized to
|
||||
audio with the selected pipeline.
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
Calls _internal_async_announce with media id and expects it to block
|
||||
until the announcement is completed.
|
||||
"""
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
announce_text,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_media_id,
|
||||
None,
|
||||
)
|
||||
announce_media_id = tts_media.url
|
||||
self._is_announcing = True
|
||||
|
||||
await self.async_trigger_pipeline_on_satellite(
|
||||
PipelineRunConfig(
|
||||
start_stage=PipelineStage.TTS,
|
||||
end_stage=PipelineStage.TTS,
|
||||
pipeline_entity_id=pipeline_entity_id,
|
||||
announce_text=announce_text,
|
||||
announce_media_id=announce_media_id,
|
||||
),
|
||||
try:
|
||||
if not announce_media_id:
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
announce_text,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_media_id,
|
||||
None,
|
||||
)
|
||||
|
||||
# Resolve to full URL
|
||||
announce_media_id = async_process_play_media_url(
|
||||
self.hass, tts_media.url
|
||||
)
|
||||
|
||||
# Block until announcement is finished
|
||||
await self._internal_async_announce(announce_media_id)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
|
||||
async def _internal_async_announce(self, media_id: str) -> None:
|
||||
"""Announce the media URL on the satellite and returns when finished."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_intercepting_wake_word(self) -> bool:
|
||||
"""Return true if next wake word will be intercepted."""
|
||||
return (self._wake_word_future is not None) and (
|
||||
not self._wake_word_future.cancelled()
|
||||
)
|
||||
|
||||
# Wait for device to report that announcement has finished
|
||||
if self._tts_finished_event is not None:
|
||||
try:
|
||||
await self._tts_finished_event.wait()
|
||||
finally:
|
||||
self._tts_finished_event = None
|
||||
|
||||
async def async_wait_wake(
|
||||
self,
|
||||
announce_text: str | None = None,
|
||||
announce_media_id: str | None = None,
|
||||
pipeline_entity_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Block until a wake word is detected from the satellite.
|
||||
async def async_intercept_wake_word(self) -> str | None:
|
||||
"""Intercept the next wake word from the satellite.
|
||||
|
||||
Returns the detected wake word phrase or None.
|
||||
"""
|
||||
if self._wake_word_future is not None:
|
||||
raise SatelliteBusyError()
|
||||
raise SatelliteBusyError
|
||||
|
||||
# Will cause next wake word to be intercepted in
|
||||
# _async_accept_pipeline_from_satellite
|
||||
self._wake_word_future = asyncio.Future()
|
||||
|
||||
try:
|
||||
if announce_text or announce_media_id:
|
||||
# Make announcement first
|
||||
await self.async_announce(
|
||||
announce_text or "", announce_media_id, pipeline_entity_id
|
||||
)
|
||||
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||
|
||||
try:
|
||||
return await self._wake_word_future
|
||||
finally:
|
||||
self._wake_word_future = None
|
||||
|
@ -157,12 +157,15 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
vad_sensitivity_entity_id: str | None = None,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||
if (self._wake_word_future is not None) and (
|
||||
not self._wake_word_future.cancelled()
|
||||
):
|
||||
# Intercepting wake word
|
||||
_LOGGER.debug("Intercepted wake word: %s", wake_word_phrase)
|
||||
"""Trigger an Assist pipeline in Home Assistant from a satellite."""
|
||||
if self.is_intercepting_wake_word:
|
||||
# Intercepting wake word and immediately end pipeline
|
||||
_LOGGER.debug(
|
||||
"Intercepted wake word: %s (entity_id=%s)",
|
||||
wake_word_phrase,
|
||||
self.entity_id,
|
||||
)
|
||||
assert self._wake_word_future is not None
|
||||
self._wake_word_future.set_result(wake_word_phrase)
|
||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||
return
|
||||
|
@ -265,6 +268,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
self._tts_finished_event.set()
|
||||
|
||||
def _resolve_pipeline(self, pipeline_entity_id: str | None) -> str | None:
|
||||
"""Resolve pipeline from select entity to id."""
|
||||
if not pipeline_entity_id:
|
||||
return None
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
"""Models for assist satellite."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntFlag, StrEnum
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineStage
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
@ -25,25 +22,5 @@ class AssistSatelliteState(StrEnum):
|
|||
class AssistSatelliteEntityFeature(IntFlag):
|
||||
"""Supported features of Assist satellite entity."""
|
||||
|
||||
TRIGGER_PIPELINE = 1
|
||||
"""Device supports remote triggering of a pipeline."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRunConfig:
|
||||
"""Configuration for a satellite pipeline run."""
|
||||
|
||||
start_stage: PipelineStage
|
||||
"""Start stage of the pipeline to run."""
|
||||
|
||||
end_stage: PipelineStage
|
||||
"""End stage of the pipeline to run."""
|
||||
|
||||
pipeline_entity_id: str | None = None
|
||||
"""Id of the entity with which pipeline to run."""
|
||||
|
||||
announce_text: str | None = None
|
||||
"""Text to announce using text-to-speech."""
|
||||
|
||||
announce_media_id: str | None = None
|
||||
"""Media id to announce."""
|
||||
ANNOUNCE = 1
|
||||
"""Device supports remotely triggered announcements."""
|
||||
|
|
80
homeassistant/components/assist_satellite/websocket_api.py
Normal file
80
homeassistant/components/assist_satellite/websocket_api.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
"""Assist satellite Websocket API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
from .models import AssistSatelliteEntityFeature
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||
websocket_api.async_register_command(hass, websocket_announce)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||
vol.Required("entity_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Intercept the next wake word from a satellite."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
satellite = component.get_entity(msg["entity_id"])
|
||||
if satellite is None:
|
||||
connection.send_error(msg["id"], "entity_not_found", "Entity not found")
|
||||
return
|
||||
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_satellite/announce",
|
||||
vol.Required("entity_id"): str,
|
||||
vol.Required(vol.Any("text", "media_id")): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_announce(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Announce text or a media id on the satellite."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
satellite = component.get_entity(msg["entity_id"])
|
||||
if satellite is None:
|
||||
connection.send_error(msg["id"], "entity_not_found", "Entity not found")
|
||||
return
|
||||
|
||||
if (satellite.supported_features is None) or (
|
||||
not (satellite.supported_features & AssistSatelliteEntityFeature.ANNOUNCE)
|
||||
):
|
||||
connection.send_message(
|
||||
websocket_api.error_message(
|
||||
msg["id"], ERR_NOT_SUPPORTED, "Satellite does not support announcements"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await satellite.async_announce(msg.get("text", ""), msg.get("media_id"))
|
||||
connection.send_result(msg["id"], {})
|
|
@ -33,7 +33,6 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import EsphomeAssistEntity
|
||||
|
@ -102,9 +101,7 @@ class EsphomeAssistSatellite(
|
|||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_supported_features = (
|
||||
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
)
|
||||
_attr_supported_features = assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -123,7 +120,6 @@ class EsphomeAssistSatellite(
|
|||
self._is_running: bool = True
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._pipeline_runs: dict[str, assist_satellite.PipelineRunConfig] = {}
|
||||
self._tts_streaming_task: asyncio.Task | None = None
|
||||
self._udp_server: VoiceAssistantUDPServer | None = None
|
||||
|
||||
|
@ -166,28 +162,13 @@ class EsphomeAssistSatellite(
|
|||
)
|
||||
)
|
||||
|
||||
# async def test() -> None:
|
||||
# await asyncio.sleep(5)
|
||||
# await self.async_announce("This is a test.")
|
||||
|
||||
self.config_entry.async_create_background_task(self.hass, test(), "test")
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
self._is_running = False
|
||||
self._stop_pipeline()
|
||||
|
||||
async def async_trigger_pipeline_on_satellite(
|
||||
self,
|
||||
run_config: assist_satellite.PipelineRunConfig,
|
||||
) -> None:
|
||||
"""Triggers a remote pipeline run on the satellite."""
|
||||
pipeline_run_id = ulid_now()
|
||||
self._pipeline_runs[pipeline_run_id] = run_config
|
||||
self.cli.trigger_voice_assistant_pipeline(
|
||||
pipeline_run_id, run_config.announce_text, run_config.announce_media_id
|
||||
)
|
||||
_LOGGER.debug("Triggered remote pipeline run (id=%s)", pipeline_run_id)
|
||||
async def _internal_async_announce(self, media_id: str) -> None:
|
||||
self.cli.send_voice_assistant_announce(media_id)
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
@ -257,7 +238,6 @@ class EsphomeAssistSatellite(
|
|||
flags: int,
|
||||
audio_settings: VoiceAssistantAudioSettings,
|
||||
wake_word_phrase: str | None,
|
||||
pipeline_run_id: str | None,
|
||||
) -> int | None:
|
||||
"""Handle pipeline run request."""
|
||||
# Clear audio queue
|
||||
|
@ -265,7 +245,7 @@ class EsphomeAssistSatellite(
|
|||
await self._audio_queue.get()
|
||||
|
||||
if self._tts_streaming_task is not None:
|
||||
# Cancel any exiting TTS response
|
||||
# Cancel current TTS response
|
||||
self._tts_streaming_task.cancel()
|
||||
self._tts_streaming_task = None
|
||||
|
||||
|
@ -290,28 +270,19 @@ class EsphomeAssistSatellite(
|
|||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-pipeline",
|
||||
)
|
||||
vad_sensitivity_id = ent_reg.async_get_entity_id(
|
||||
vad_sensitivity_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.SELECT,
|
||||
DOMAIN,
|
||||
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||
)
|
||||
|
||||
# Determine if this pipeline was triggered remotely or on-device
|
||||
if (pipeline_run_id is not None) and (
|
||||
(run_config := self._pipeline_runs.pop(pipeline_run_id)) is not None
|
||||
):
|
||||
# HA triggered pipeline
|
||||
start_stage = run_config.start_stage
|
||||
end_stage = run_config.end_stage
|
||||
pipeline_entity_id = run_config.pipeline_entity_id or pipeline_entity_id
|
||||
# Device triggered pipeline (wake word, etc.)
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
# Device triggered pipeline (wake word, etc.)
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
start_stage = PipelineStage.STT
|
||||
start_stage = PipelineStage.STT
|
||||
|
||||
end_stage = PipelineStage.TTS
|
||||
end_stage = PipelineStage.TTS
|
||||
|
||||
# Run the pipeline
|
||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||
|
@ -323,6 +294,7 @@ class EsphomeAssistSatellite(
|
|||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
pipeline_entity_id=pipeline_entity_id,
|
||||
vad_sensitivity_entity_id=vad_sensitivity_entity_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
),
|
||||
"esphome_assist_satellite_pipeline",
|
||||
|
@ -391,37 +363,34 @@ class EsphomeAssistSatellite(
|
|||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
or (wav_file.getsampwidth() != sample_width)
|
||||
or (wav_file.getnchannels() != sample_channels)
|
||||
):
|
||||
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||
return
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
or (wav_file.getsampwidth() != sample_width)
|
||||
or (wav_file.getnchannels() != sample_channels)
|
||||
):
|
||||
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||
|
||||
while True:
|
||||
chunk = wav_file.readframes(samples_per_chunk)
|
||||
if not chunk:
|
||||
break
|
||||
while True:
|
||||
chunk = wav_file.readframes(samples_per_chunk)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if self._udp_server is not None:
|
||||
self._udp_server.send_audio_bytes(chunk)
|
||||
else:
|
||||
self.cli.send_voice_assistant_audio(chunk)
|
||||
if self._udp_server is not None:
|
||||
self._udp_server.send_audio_bytes(chunk)
|
||||
else:
|
||||
self.cli.send_voice_assistant_audio(chunk)
|
||||
|
||||
# Wait for 90% of the duration of the audio that was
|
||||
# sent for it to be played. This will overrun the
|
||||
# device's buffer for very long audio, so using a media
|
||||
# player is preferred.
|
||||
samples_in_chunk = len(chunk) // (
|
||||
sample_width * sample_channels
|
||||
)
|
||||
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||
# Wait for 90% of the duration of the audio that was
|
||||
# sent for it to be played. This will overrun the
|
||||
# device's buffer for very long audio, so using a media
|
||||
# player is preferred.
|
||||
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||
except asyncio.CancelledError:
|
||||
return # Don't trigger state change
|
||||
finally:
|
||||
|
@ -433,7 +402,7 @@ class EsphomeAssistSatellite(
|
|||
self.tts_response_finished()
|
||||
|
||||
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
|
||||
"""Yields audio chunks from the queue until None."""
|
||||
"""Yield audio chunks from the queue until None."""
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if not chunk:
|
||||
|
@ -442,12 +411,12 @@ class EsphomeAssistSatellite(
|
|||
yield chunk
|
||||
|
||||
def _stop_pipeline(self) -> None:
|
||||
"""Requests pipeline to be stopped."""
|
||||
"""Request pipeline to be stopped."""
|
||||
self._audio_queue.put_nowait(None)
|
||||
_LOGGER.debug("Requested pipeline stop")
|
||||
|
||||
async def _start_udp_server(self) -> int:
|
||||
"""Starts a UDP server on a random free port."""
|
||||
"""Start a UDP server on a random free port."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(("", 0)) # random free port
|
||||
|
@ -466,7 +435,7 @@ class EsphomeAssistSatellite(
|
|||
return cast(int, sock.getsockname()[1])
|
||||
|
||||
def _stop_udp_server(self) -> None:
|
||||
"""Stops the UDP server if it's running."""
|
||||
"""Stop the UDP server if it's running."""
|
||||
if self._udp_server is None:
|
||||
return
|
||||
|
||||
|
@ -488,6 +457,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
def __init__(
|
||||
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize protocol."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._audio_queue = audio_queue
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from homeassistant.components.assist_pipeline import PipelineEvent
|
|||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -30,6 +31,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
|
|
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
181
tests/components/assist_satellite/test_websocket_api.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
"""Test the Assist Satellite websocket API."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteEntityFeature
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
||||
|
||||
|
||||
async def audio_stream() -> AsyncIterable[bytes]:
|
||||
"""Empty audio stream."""
|
||||
yield b""
|
||||
|
||||
|
||||
async def test_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/intercept_wake_word command."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineInput.validate",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_recognize_intent",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_text_to_speech",
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(entity, "on_pipeline_event") as mock_on_pipeline_event,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{"type": "assist_satellite/intercept_wake_word", "entity_id": ENTITY_ID}
|
||||
)
|
||||
|
||||
# Wait for interception to start
|
||||
while not entity.is_intercepting_wake_word:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Start a pipeline with a wake word
|
||||
await entity._async_accept_pipeline_from_satellite(
|
||||
audio_stream=audio_stream(),
|
||||
start_stage=PipelineStage.STT,
|
||||
end_stage=PipelineStage.TTS,
|
||||
wake_word_phrase="test wake word",
|
||||
)
|
||||
|
||||
# Verify that wake word was intercepted
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == {"wake_word_phrase": "test wake word"}
|
||||
|
||||
# Verify that only run end event was sent to pipeline
|
||||
mock_on_pipeline_event.assert_called_once_with(
|
||||
PipelineEvent(PipelineEventType.RUN_END, data=None, timestamp=ANY)
|
||||
)
|
||||
|
||||
|
||||
async def test_announce_not_supported(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with an entity that doesn't support announcements."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch.object(
|
||||
entity, "_attr_supported_features", AssistSatelliteEntityFeature(0)
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"media_id": "test media id",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == ERR_NOT_SUPPORTED
|
||||
|
||||
|
||||
async def test_announce_media_id(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with media id."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
entity, "_internal_async_announce"
|
||||
) as mock_internal_async_announce,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"media_id": "test media id",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
# Verify media id was passed through
|
||||
mock_internal_async_announce.assert_called_once_with("test media id")
|
||||
|
||||
|
||||
async def test_announce_text(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test assist_satellite/announce command with text."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(url="test media id", mime_type=""),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_process_play_media_url",
|
||||
return_value="test media id",
|
||||
),
|
||||
patch.object(
|
||||
entity, "_internal_async_announce"
|
||||
) as mock_internal_async_announce,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/announce",
|
||||
"entity_id": ENTITY_ID,
|
||||
"text": "test text",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
# Verify media id was passed through
|
||||
mock_internal_async_announce.assert_called_once_with("test media id")
|
Loading…
Add table
Reference in a new issue