Add async_announce

This commit is contained in:
Michael Hansen 2024-08-29 15:41:23 -05:00
parent d375bfaefe
commit f0c49b3995
7 changed files with 383 additions and 172 deletions

View file

@ -10,18 +10,14 @@ from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import (
AssistSatelliteEntityFeature,
AssistSatelliteState,
PipelineRunConfig,
)
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
from .websocket_api import async_register_websocket_api
__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"PipelineRunConfig",
"AssistSatelliteEntityFeature",
]
@ -35,6 +31,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER, DOMAIN, hass
)
await component.async_setup(config)
async_register_websocket_api(hass)
return True

View file

@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
@ -28,11 +29,7 @@ from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .errors import SatelliteBusyError
from .models import (
AssistSatelliteEntityFeature,
AssistSatelliteState,
PipelineRunConfig,
)
from .models import AssistSatelliteEntityFeature, AssistSatelliteState
_LOGGER = logging.getLogger(__name__)
@ -54,6 +51,7 @@ class AssistSatelliteEntity(entity.Entity):
_conversation_id: str | None = None
_conversation_id_time: float | None = None
_is_announcing: bool = False
_tts_finished_event: asyncio.Event | None = None
_wake_word_future: asyncio.Future[str | None] | None = None
@ -61,14 +59,10 @@ class AssistSatelliteEntity(entity.Entity):
"""Run when entity about to be added to hass."""
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
async def async_trigger_pipeline_on_satellite(
self, run_config: PipelineRunConfig
) -> None:
"""Run a pipeline on the satellite with the configuration.
Requires TRIGGER_PIPELINE supported feature.
"""
raise NotImplementedError
@property
def is_announcing(self) -> bool:
"""Returns true if currently announcing."""
return self._is_announcing
async def async_announce(
self,
@ -76,10 +70,20 @@ class AssistSatelliteEntity(entity.Entity):
announce_media_id: str | None = None,
pipeline_entity_id: str | None = None,
) -> None:
"""Play an announcement on the satellite."""
if self._tts_finished_event is not None:
raise SatelliteBusyError()
"""Play an announcement on the satellite.
If announce_media_id is not provided, announce_text is synthesized to
audio with the selected pipeline.
Calls _internal_async_announce with media id and expects it to block
until the announcement is completed.
"""
if self._is_announcing:
raise SatelliteBusyError
self._is_announcing = True
try:
if not announce_media_id:
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline(pipeline_entity_id)
@ -101,47 +105,43 @@ class AssistSatelliteEntity(entity.Entity):
tts_media_id,
None,
)
announce_media_id = tts_media.url
await self.async_trigger_pipeline_on_satellite(
PipelineRunConfig(
start_stage=PipelineStage.TTS,
end_stage=PipelineStage.TTS,
pipeline_entity_id=pipeline_entity_id,
announce_text=announce_text,
announce_media_id=announce_media_id,
),
# Resolve to full URL
announce_media_id = async_process_play_media_url(
self.hass, tts_media.url
)
# Wait for device to report that announcement has finished
if self._tts_finished_event is not None:
try:
await self._tts_finished_event.wait()
# Block until announcement is finished
await self._internal_async_announce(announce_media_id)
finally:
self._tts_finished_event = None
self._is_announcing = False
async def async_wait_wake(
self,
announce_text: str | None = None,
announce_media_id: str | None = None,
pipeline_entity_id: str | None = None,
) -> str | None:
"""Block until a wake word is detected from the satellite.
async def _internal_async_announce(self, media_id: str) -> None:
"""Announce the media URL on the satellite and returns when finished."""
raise NotImplementedError
@property
def is_intercepting_wake_word(self) -> bool:
"""Return true if next wake word will be intercepted."""
return (self._wake_word_future is not None) and (
not self._wake_word_future.cancelled()
)
async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite.
Returns the detected wake word phrase or None.
"""
if self._wake_word_future is not None:
raise SatelliteBusyError()
raise SatelliteBusyError
# Will cause next wake word to be intercepted in
# _async_accept_pipeline_from_satellite
self._wake_word_future = asyncio.Future()
try:
if announce_text or announce_media_id:
# Make announcement first
await self.async_announce(
announce_text or "", announce_media_id, pipeline_entity_id
)
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
try:
return await self._wake_word_future
finally:
self._wake_word_future = None
@ -157,12 +157,15 @@ class AssistSatelliteEntity(entity.Entity):
vad_sensitivity_entity_id: str | None = None,
wake_word_phrase: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
if (self._wake_word_future is not None) and (
not self._wake_word_future.cancelled()
):
# Intercepting wake word
_LOGGER.debug("Intercepted wake word: %s", wake_word_phrase)
"""Trigger an Assist pipeline in Home Assistant from a satellite."""
if self.is_intercepting_wake_word:
# Intercepting wake word and immediately end pipeline
_LOGGER.debug(
"Intercepted wake word: %s (entity_id=%s)",
wake_word_phrase,
self.entity_id,
)
assert self._wake_word_future is not None
self._wake_word_future.set_result(wake_word_phrase)
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
return
@ -265,6 +268,7 @@ class AssistSatelliteEntity(entity.Entity):
self._tts_finished_event.set()
def _resolve_pipeline(self, pipeline_entity_id: str | None) -> str | None:
"""Resolve pipeline from select entity to id."""
if not pipeline_entity_id:
return None

View file

@ -1,10 +1,7 @@
"""Models for assist satellite."""
from dataclasses import dataclass
from enum import IntFlag, StrEnum
from homeassistant.components.assist_pipeline import PipelineStage
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
@ -25,25 +22,5 @@ class AssistSatelliteState(StrEnum):
class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""
TRIGGER_PIPELINE = 1
"""Device supports remote triggering of a pipeline."""
@dataclass(frozen=True)
class PipelineRunConfig:
"""Configuration for a satellite pipeline run."""
start_stage: PipelineStage
"""Start stage of the pipeline to run."""
end_stage: PipelineStage
"""End stage of the pipeline to run."""
pipeline_entity_id: str | None = None
"""Id of the entity with which pipeline to run."""
announce_text: str | None = None
"""Text to announce using text-to-speech."""
announce_media_id: str | None = None
"""Media id to announce."""
ANNOUNCE = 1
"""Device supports remotely triggered announcements."""

View file

@ -0,0 +1,80 @@
"""Assist satellite Websocket API."""
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_component import EntityComponent
from .const import DOMAIN
from .entity import AssistSatelliteEntity
from .models import AssistSatelliteEntityFeature
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register the Assist satellite websocket commands with Home Assistant."""
    # Registration order matches the order the handlers are defined below.
    for handler in (websocket_intercept_wake_word, websocket_announce):
        websocket_api.async_register_command(hass, handler)
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/intercept_wake_word",
        vol.Required("entity_id"): str,
    }
)
@websocket_api.async_response
async def websocket_intercept_wake_word(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Intercept the next wake word from a satellite.

    Looks up the satellite entity named by ``msg["entity_id"]`` and awaits its
    ``async_intercept_wake_word``. The result is sent back as
    ``{"wake_word_phrase": <phrase or None>}``. If no such entity exists, an
    ``entity_not_found`` error is sent instead.
    """
    satellites: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]

    if (target := satellites.get_entity(msg["entity_id"])) is None:
        connection.send_error(msg["id"], "entity_not_found", "Entity not found")
        return

    # The reply is only sent once the entity call returns.
    phrase = await target.async_intercept_wake_word()
    connection.send_result(msg["id"], {"wake_word_phrase": phrase})
@callback
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_satellite/announce",
        vol.Required("entity_id"): str,
        # At least one of "text" / "media_id" must be present; both are strings.
        vol.Required(vol.Any("text", "media_id")): str,
    }
)
@websocket_api.async_response
async def websocket_announce(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Announce text or a media id on the satellite.

    Error replies:
    - ``entity_not_found`` when ``msg["entity_id"]`` does not match a satellite
      entity.
    - ``ERR_NOT_SUPPORTED`` when the entity does not advertise the ``ANNOUNCE``
      feature.

    On success an empty result is sent after ``async_announce`` returns.

    NOTE(review): ``async_announce`` raises ``SatelliteBusyError`` while another
    announcement is in progress. That error is not translated into a dedicated
    websocket error code here.
    """
    component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
    satellite = component.get_entity(msg["entity_id"])
    if satellite is None:
        connection.send_error(msg["id"], "entity_not_found", "Entity not found")
        return

    # Read the feature flags once; None means the entity declares no features.
    features = satellite.supported_features
    if (features is None) or not (features & AssistSatelliteEntityFeature.ANNOUNCE):
        # Use send_error like the other error path instead of hand-building the
        # error message.
        connection.send_error(
            msg["id"], ERR_NOT_SUPPORTED, "Satellite does not support announcements"
        )
        return

    # "text" defaults to an empty string when only a media id was supplied.
    await satellite.async_announce(msg.get("text", ""), msg.get("media_id"))
    connection.send_result(msg["id"], {})

View file

@ -33,7 +33,6 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.ulid import ulid_now
from .const import DOMAIN
from .entity import EsphomeAssistEntity
@ -102,9 +101,7 @@ class EsphomeAssistSatellite(
translation_key="assist_satellite",
entity_category=EntityCategory.CONFIG,
)
_attr_supported_features = (
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
)
_attr_supported_features = assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
def __init__(
self,
@ -123,7 +120,6 @@ class EsphomeAssistSatellite(
self._is_running: bool = True
self._pipeline_task: asyncio.Task | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._pipeline_runs: dict[str, assist_satellite.PipelineRunConfig] = {}
self._tts_streaming_task: asyncio.Task | None = None
self._udp_server: VoiceAssistantUDPServer | None = None
@ -166,28 +162,13 @@ class EsphomeAssistSatellite(
)
)
# async def test() -> None:
# await asyncio.sleep(5)
# await self.async_announce("This is a test.")
self.config_entry.async_create_background_task(self.hass, test(), "test")
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
self._is_running = False
self._stop_pipeline()
async def async_trigger_pipeline_on_satellite(
self,
run_config: assist_satellite.PipelineRunConfig,
) -> None:
"""Triggers a remote pipeline run on the satellite."""
pipeline_run_id = ulid_now()
self._pipeline_runs[pipeline_run_id] = run_config
self.cli.trigger_voice_assistant_pipeline(
pipeline_run_id, run_config.announce_text, run_config.announce_media_id
)
_LOGGER.debug("Triggered remote pipeline run (id=%s)", pipeline_run_id)
async def _internal_async_announce(self, media_id: str) -> None:
self.cli.send_voice_assistant_announce(media_id)
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
@ -257,7 +238,6 @@ class EsphomeAssistSatellite(
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
pipeline_run_id: str | None,
) -> int | None:
"""Handle pipeline run request."""
# Clear audio queue
@ -265,7 +245,7 @@ class EsphomeAssistSatellite(
await self._audio_queue.get()
if self._tts_streaming_task is not None:
# Cancel any exiting TTS response
# Cancel current TTS response
self._tts_streaming_task.cancel()
self._tts_streaming_task = None
@ -290,21 +270,12 @@ class EsphomeAssistSatellite(
DOMAIN,
f"{self.entry_data.device_info.mac_address}-pipeline",
)
vad_sensitivity_id = ent_reg.async_get_entity_id(
vad_sensitivity_entity_id = ent_reg.async_get_entity_id(
Platform.SELECT,
DOMAIN,
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
)
# Determine if this pipeline was triggered remotely or on-device
if (pipeline_run_id is not None) and (
(run_config := self._pipeline_runs.pop(pipeline_run_id)) is not None
):
# HA triggered pipeline
start_stage = run_config.start_stage
end_stage = run_config.end_stage
pipeline_entity_id = run_config.pipeline_entity_id or pipeline_entity_id
else:
# Device triggered pipeline (wake word, etc.)
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
start_stage = PipelineStage.WAKE_WORD
@ -323,6 +294,7 @@ class EsphomeAssistSatellite(
start_stage=start_stage,
end_stage=end_stage,
pipeline_entity_id=pipeline_entity_id,
vad_sensitivity_entity_id=vad_sensitivity_entity_id,
wake_word_phrase=wake_word_phrase,
),
"esphome_assist_satellite_pipeline",
@ -391,8 +363,7 @@ class EsphomeAssistSatellite(
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
if (
(wav_file.getframerate() != sample_rate)
or (wav_file.getsampwidth() != sample_width)
@ -417,9 +388,7 @@ class EsphomeAssistSatellite(
# sent for it to be played. This will overrun the
# device's buffer for very long audio, so using a media
# player is preferred.
samples_in_chunk = len(chunk) // (
sample_width * sample_channels
)
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
seconds_in_chunk = samples_in_chunk / sample_rate
await asyncio.sleep(seconds_in_chunk * 0.9)
except asyncio.CancelledError:
@ -433,7 +402,7 @@ class EsphomeAssistSatellite(
self.tts_response_finished()
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
"""Yields audio chunks from the queue until None."""
"""Yield audio chunks from the queue until None."""
while True:
chunk = await self._audio_queue.get()
if not chunk:
@ -442,12 +411,12 @@ class EsphomeAssistSatellite(
yield chunk
def _stop_pipeline(self) -> None:
"""Requests pipeline to be stopped."""
"""Request pipeline to be stopped."""
self._audio_queue.put_nowait(None)
_LOGGER.debug("Requested pipeline stop")
async def _start_udp_server(self) -> int:
"""Starts a UDP server on a random free port."""
"""Start a UDP server on a random free port."""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", 0)) # random free port
@ -466,7 +435,7 @@ class EsphomeAssistSatellite(
return cast(int, sock.getsockname()[1])
def _stop_udp_server(self) -> None:
"""Stops the UDP server if it's running."""
"""Stop the UDP server if it's running."""
if self._udp_server is None:
return
@ -488,6 +457,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
def __init__(
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
) -> None:
"""Initialize protocol."""
super().__init__(*args, **kwargs)
self._audio_queue = audio_queue

View file

@ -8,6 +8,7 @@ from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
@ -30,6 +31,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""
_attr_name = "Test Entity"
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
def __init__(self) -> None:
"""Initialize the mock entity."""

View file

@ -0,0 +1,181 @@
"""Test the Assist Satellite websocket API."""
import asyncio
from collections.abc import AsyncIterable
from unittest.mock import ANY, patch
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineStage,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntityFeature
from homeassistant.components.media_source import PlayMedia
from homeassistant.components.websocket_api import ERR_NOT_SUPPORTED
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from .conftest import MockAssistSatellite
from tests.typing import WebSocketGenerator
ENTITY_ID = "assist_satellite.test_entity"
async def audio_stream() -> AsyncIterable[bytes]:
    """Yield a single empty chunk, standing in for satellite audio."""
    empty_chunk = b""
    yield empty_chunk
async def test_intercept_wake_word(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test assist_satellite/intercept_wake_word command.

    The command is sent first. Once the entity reports that it is intercepting,
    a pipeline carrying a wake word phrase is started from the satellite. The
    websocket reply must contain that phrase and the entity must see only a
    run-end pipeline event.
    """
    ws_client = await hass_ws_client(hass)

    with (
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineInput.validate",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_recognize_intent",
            return_value=None,
        ),
        patch(
            "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_text_to_speech",
            return_value=None,
        ),
        patch.object(entity, "on_pipeline_event") as pipeline_event_mock,
    ):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(
                {"type": "assist_satellite/intercept_wake_word", "entity_id": ENTITY_ID}
            )

            # Poll until the websocket handler has armed the interception.
            while not entity.is_intercepting_wake_word:
                await asyncio.sleep(0.01)

            # A satellite-initiated pipeline run delivers the wake word.
            await entity._async_accept_pipeline_from_satellite(
                audio_stream=audio_stream(),
                start_stage=PipelineStage.STT,
                end_stage=PipelineStage.TTS,
                wake_word_phrase="test wake word",
            )

            reply = await ws_client.receive_json()

        # The intercepted phrase is returned to the websocket caller.
        assert reply["success"]
        assert reply["result"] == {"wake_word_phrase": "test wake word"}

        # The pipeline was ended immediately: only a run-end event was emitted.
        expected_event = PipelineEvent(
            PipelineEventType.RUN_END, data=None, timestamp=ANY
        )
        pipeline_event_mock.assert_called_once_with(expected_event)
async def test_announce_not_supported(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test assist_satellite/announce command with an entity that doesn't support announcements."""
    ws_client = await hass_ws_client(hass)
    request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "media_id": "test media id",
    }

    # Temporarily strip every supported feature from the mock entity.
    no_features = AssistSatelliteEntityFeature(0)
    with patch.object(entity, "_attr_supported_features", no_features):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(request)
            reply = await ws_client.receive_json()

        # The command is rejected with the "not supported" error code.
        assert not reply["success"]
        assert reply["error"]["code"] == ERR_NOT_SUPPORTED
async def test_announce_media_id(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test assist_satellite/announce command with media id."""
    ws_client = await hass_ws_client(hass)
    request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "media_id": "test media id",
    }

    with patch.object(entity, "_internal_async_announce") as announce_mock:
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(request)
            reply = await ws_client.receive_json()

        assert reply["success"]

        # The media id from the request reaches the entity unchanged.
        announce_mock.assert_called_once_with("test media id")
async def test_announce_text(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test assist_satellite/announce command with text.

    The TTS media-source id, media resolution and URL processing are all
    patched, so the announcement resolves to a fixed media id.
    """
    ws_client = await hass_ws_client(hass)
    request = {
        "type": "assist_satellite/announce",
        "entity_id": ENTITY_ID,
        "text": "test text",
    }

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
            return_value="",
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.media_source.async_resolve_media",
            return_value=PlayMedia(url="test media id", mime_type=""),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_process_play_media_url",
            return_value="test media id",
        ),
        patch.object(entity, "_internal_async_announce") as announce_mock,
    ):
        async with asyncio.timeout(1):
            await ws_client.send_json_auto_id(request)
            reply = await ws_client.receive_json()

        assert reply["success"]

        # The resolved media id is what the entity is asked to announce.
        announce_mock.assert_called_once_with("test media id")