Add websocket command to capture audio from a device (#103936)
* Add websocket command to capture audio from a device * Update homeassistant/components/assist_pipeline/pipeline.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Add device capture test * More tests * Add logbook * Remove unnecessary check * Remove seconds and make logbook message past tense --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
4536fb3541
commit
b3e247d5f0
8 changed files with 720 additions and 16 deletions
|
@ -9,7 +9,13 @@ from homeassistant.components import stt
|
|||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
|
||||
from .const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DATA_CONFIG,
|
||||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
)
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
|
@ -40,6 +46,7 @@ __all__ = (
|
|||
"PipelineEventType",
|
||||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
|
|
|
@ -11,3 +11,5 @@ CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
|||
|
||||
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
||||
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
||||
|
||||
EVENT_RECORDING = f"{DOMAIN}_recording"
|
||||
|
|
39
homeassistant/components/assist_pipeline/logbook.py
Normal file
39
homeassistant/components/assist_pipeline/logbook.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
"""Describe assist_pipeline logbook events."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME
|
||||
from homeassistant.const import ATTR_DEVICE_ID
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
|
||||
from .const import DOMAIN, EVENT_RECORDING
|
||||
|
||||
|
||||
@callback
def async_describe_events(
    hass: HomeAssistant,
    async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None],
) -> None:
    """Register the logbook description for assist_pipeline recording events.

    ``async_describe_event`` is called once, with this integration's domain,
    the EVENT_RECORDING event type and a callback that turns such an event
    into a logbook name/message pair.
    """
    device_registry = dr.async_get(hass)

    @callback
    def async_describe_logbook_event(event: Event) -> dict[str, str]:
        """Build the logbook entry for a single recording event."""
        device_name = "Unknown device"

        # Look the device up with async_get rather than subscripting
        # device_registry.devices: logbook rows can outlive the device they
        # refer to, and a subscript lookup on a removed device would raise
        # instead of falling back to the generic name.
        device_id = event.data.get(ATTR_DEVICE_ID)
        device: dr.DeviceEntry | None = (
            device_registry.async_get(device_id) if device_id is not None else None
        )
        if device is not None:
            # Prefer the user-assigned name, then the integration-provided one.
            device_name = device.name_by_user or device.name or device_name

        return {
            LOGBOOK_ENTRY_NAME: device_name,
            LOGBOOK_ENTRY_MESSAGE: f"{device_name} started recording audio",
        }

    async_describe_event(DOMAIN, EVENT_RECORDING, async_describe_logbook_event)
|
|
@ -503,6 +503,9 @@ class PipelineRun:
|
|||
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
|
||||
"""Buffer used when splitting audio into chunks for audio processing"""
|
||||
|
||||
_device_id: str | None = None
|
||||
"""Optional device id set during run start."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
|
@ -554,7 +557,8 @@ class PipelineRun:
|
|||
|
||||
def start(self, device_id: str | None) -> None:
|
||||
"""Emit run start event."""
|
||||
self._start_debug_recording_thread(device_id)
|
||||
self._device_id = device_id
|
||||
self._start_debug_recording_thread()
|
||||
|
||||
data = {
|
||||
"pipeline": self.pipeline.id,
|
||||
|
@ -567,6 +571,9 @@ class PipelineRun:
|
|||
|
||||
async def end(self) -> None:
|
||||
"""Emit run end event."""
|
||||
# Signal end of stream to listeners
|
||||
self._capture_chunk(None)
|
||||
|
||||
# Stop the recording thread before emitting run-end.
|
||||
# This ensures that files are properly closed if the event handler reads them.
|
||||
await self._stop_debug_recording_thread()
|
||||
|
@ -746,9 +753,7 @@ class PipelineRun:
|
|||
if self.abort_wake_word_detection:
|
||||
raise WakeWordDetectionAborted
|
||||
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(chunk.audio)
|
||||
|
||||
self._capture_chunk(chunk.audio)
|
||||
yield chunk.audio, chunk.timestamp_ms
|
||||
|
||||
# Wake-word-detection occurs *after* the wake word was actually
|
||||
|
@ -870,8 +875,7 @@ class PipelineRun:
|
|||
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
||||
sent_vad_start = False
|
||||
async for chunk in audio_stream:
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(chunk.audio)
|
||||
self._capture_chunk(chunk.audio)
|
||||
|
||||
if stt_vad is not None:
|
||||
if not stt_vad.process(chunk_seconds, chunk.is_speech):
|
||||
|
@ -1057,7 +1061,28 @@ class PipelineRun:
|
|||
|
||||
return tts_media.url
|
||||
|
||||
def _start_debug_recording_thread(self, device_id: str | None) -> None:
|
||||
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
    """Hand one audio chunk to every active capture sink.

    The chunk (or None) is always offered to the debug recording queue when
    that queue exists. It is additionally offered to the per-device capture
    queue when this run has a device id and a queue is registered for it.
    """
    debug_queue = self.debug_recording_queue
    if debug_queue is not None:
        # Debug WAV file recording
        debug_queue.put_nowait(audio_bytes)

    if self._device_id is None:
        # Run is not tied to a device, so there is nothing to capture for
        return

    # Device audio capture
    pipeline_data: PipelineData = self.hass.data[DOMAIN]
    device_queue = pipeline_data.device_audio_queues.get(self._device_id)
    if device_queue is not None:
        try:
            device_queue.queue.put_nowait(audio_bytes)
        except asyncio.QueueFull:
            # Drop the chunk, but remember that audio was lost
            device_queue.overflow = True
            _LOGGER.warning("Audio queue full for device %s", self._device_id)
|
||||
|
||||
def _start_debug_recording_thread(self) -> None:
|
||||
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
|
||||
if self.debug_recording_thread is not None:
|
||||
# Already started
|
||||
|
@ -1068,7 +1093,7 @@ class PipelineRun:
|
|||
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
||||
CONF_DEBUG_RECORDING_DIR
|
||||
):
|
||||
if device_id is None:
|
||||
if self._device_id is None:
|
||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||
run_recording_dir = (
|
||||
Path(debug_recording_dir)
|
||||
|
@ -1079,7 +1104,7 @@ class PipelineRun:
|
|||
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
|
||||
run_recording_dir = (
|
||||
Path(debug_recording_dir)
|
||||
/ device_id
|
||||
/ self._device_id
|
||||
/ self.pipeline.name
|
||||
/ str(time.monotonic_ns())
|
||||
)
|
||||
|
@ -1100,8 +1125,8 @@ class PipelineRun:
|
|||
# Not running
|
||||
return
|
||||
|
||||
# Signal thread to stop gracefully
|
||||
self.debug_recording_queue.put(None)
|
||||
# NOTE: Expecting a None to have been put in self.debug_recording_queue
|
||||
# in self.end() to signal the thread to stop.
|
||||
|
||||
# Wait until the thread has finished to ensure that files are fully written
|
||||
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
|
||||
|
@ -1632,6 +1657,20 @@ class PipelineRuns:
|
|||
pipeline_run.abort_wake_word_detection = True
|
||||
|
||||
|
||||
@dataclass
class DeviceAudioQueue:
    """State for one in-progress audio capture of a satellite device."""

    # Captured audio chunks; a None item is the signal to stop.
    queue: asyncio.Queue[bytes | None]

    # Identifies this particular capture so that the websocket API only
    # cleans up the queue it created itself.
    id: str = field(default_factory=ulid_util.ulid)

    # Becomes True once audio had to be dropped because the queue was full.
    overflow: bool = False
|
||||
|
||||
|
||||
class PipelineData:
|
||||
"""Store and debug data stored in hass.data."""
|
||||
|
||||
|
@ -1641,6 +1680,7 @@ class PipelineData:
|
|||
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
|
||||
self.pipeline_devices: set[str] = set()
|
||||
self.pipeline_runs = PipelineRuns(pipeline_store)
|
||||
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -3,22 +3,31 @@ import asyncio
|
|||
|
||||
# Suppressing disable=deprecated-module is needed for Python 3.11
|
||||
import audioop # pylint: disable=deprecated-module
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any
|
||||
import math
|
||||
from typing import Any, Final
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
|
||||
from .const import (
|
||||
DEFAULT_PIPELINE_TIMEOUT,
|
||||
DEFAULT_WAKE_WORD_TIMEOUT,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
)
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
DeviceAudioQueue,
|
||||
PipelineData,
|
||||
PipelineError,
|
||||
PipelineEvent,
|
||||
|
@ -32,6 +41,11 @@ from .pipeline import (
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CAPTURE_RATE: Final = 16000
|
||||
CAPTURE_WIDTH: Final = 2
|
||||
CAPTURE_CHANNELS: Final = 1
|
||||
MAX_CAPTURE_TIMEOUT: Final = 60.0
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
|
@ -40,6 +54,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
websocket_api.async_register_command(hass, websocket_list_languages)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
websocket_api.async_register_command(hass, websocket_device_capture)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
|
@ -371,3 +386,101 @@ async def websocket_list_languages(
|
|||
else pipeline_languages
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/device/capture",
        vol.Required("device_id"): str,
        vol.Required("timeout"): vol.All(
            # Accepts the half-open range (0, MAX_CAPTURE_TIMEOUT]
            vol.Coerce(float),
            vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT),
        ),
    }
)
@websocket_api.async_response
async def websocket_device_capture(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Stream raw audio from a satellite device to the requesting client.

    Audio chunks are delivered as base64 encoded "audio" events until a stop
    signal (None) is read from the queue or the requested timeout elapses;
    a final "end" event reports whether any audio was dropped.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    device_id = msg["device_id"]

    # Wall clock duration of the capture, in seconds
    timeout_seconds = msg["timeout"]
    msg_id = msg["id"]

    # The chunk size is unknown, so size the queue for the worst case of one
    # 16-bit sample per item, plus one slot for the terminating None.
    queue_capacity = math.ceil(timeout_seconds * CAPTURE_RATE) + 1
    audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=queue_capacity))

    # A device supports a single capture at a time by design: starting a new
    # capture stops whichever one was registered before.
    previous_queue = pipeline_data.device_audio_queues.pop(device_id, None)
    if previous_queue is not None:
        with contextlib.suppress(asyncio.QueueFull):
            # Tell the other websocket command that it has been replaced
            previous_queue.queue.put_nowait(None)

    pipeline_data.device_audio_queues[device_id] = audio_queue

    def clean_up_queue() -> None:
        # Remove the registered queue, but only if it is still ours
        registered = pipeline_data.device_audio_queues.get(device_id)
        if registered is None or registered.id != audio_queue.id:
            return
        pipeline_data.device_audio_queues.pop(device_id)

    # Unsubscribing cleans up the queue
    connection.subscriptions[msg_id] = clean_up_queue

    # The audio itself follows as events
    connection.send_result(msg_id)

    # Record to logbook
    hass.bus.async_fire(
        EVENT_RECORDING,
        {
            ATTR_DEVICE_ID: device_id,
            ATTR_SECONDS: timeout_seconds,
        },
    )

    try:
        with contextlib.suppress(asyncio.TimeoutError):
            async with asyncio.timeout(timeout_seconds):
                # None is the stop signal; everything else is audio
                while (audio_bytes := await audio_queue.queue.get()) is not None:
                    connection.send_event(
                        msg_id,
                        {
                            "type": "audio",
                            "rate": CAPTURE_RATE,  # hertz
                            "width": CAPTURE_WIDTH,  # bytes
                            "channels": CAPTURE_CHANNELS,
                            "audio": base64.b64encode(audio_bytes).decode("ascii"),
                        },
                    )

        # Stopped or timed out: tell the client the capture is over
        connection.send_event(
            msg_id, {"type": "end", "overflow": audio_queue.overflow}
        )
    finally:
        clean_up_queue()
|
||||
|
|
|
@ -487,6 +487,119 @@
|
|||
# name: test_audio_pipeline_with_wake_word_timeout.3
|
||||
None
|
||||
# ---
|
||||
# name: test_device_capture
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture.2
|
||||
None
|
||||
# ---
|
||||
# name: test_device_capture_override
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_override.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_override.2
|
||||
dict({
|
||||
'audio': 'Y2h1bmsx',
|
||||
'channels': 1,
|
||||
'rate': 16000,
|
||||
'type': 'audio',
|
||||
'width': 2,
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_override.3
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_override.4
|
||||
None
|
||||
# ---
|
||||
# name: test_device_capture_override.5
|
||||
dict({
|
||||
'overflow': False,
|
||||
'type': 'end',
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_queue_full
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_queue_full.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_queue_full.2
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_device_capture_queue_full.3
|
||||
None
|
||||
# ---
|
||||
# name: test_device_capture_queue_full.4
|
||||
dict({
|
||||
'overflow': True,
|
||||
'type': 'end',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
|
42
tests/components/assist_pipeline/test_logbook.py
Normal file
42
tests/components/assist_pipeline/test_logbook.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
"""The tests for assist_pipeline logbook."""
|
||||
from homeassistant.components import assist_pipeline, logbook
|
||||
from homeassistant.const import ATTR_DEVICE_ID
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.logbook.common import MockRow, mock_humanify
|
||||
|
||||
|
||||
async def test_recording_event(
    hass: HomeAssistant, init_components, device_registry: dr.DeviceRegistry
) -> None:
    """Check the logbook entry produced for a recording event."""
    hass.config.components.add("recorder")
    assert await async_setup_component(hass, "logbook", {})

    config_entry = MockConfigEntry()
    config_entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=config_entry.entry_id,
        connections=set(),
        identifiers={("demo", "satellite-1234")},
    )
    device_registry.async_update_device(satellite.id, name="My Satellite")

    # Humanify a single recording event that refers to the satellite
    recording_row = MockRow(
        assist_pipeline.EVENT_RECORDING,
        {ATTR_DEVICE_ID: satellite.id},
    )
    logbook_entries = mock_humanify(hass, [recording_row])
    logbook_entry = logbook_entries[0]

    assert logbook_entry[logbook.LOGBOOK_ENTRY_NAME] == "My Satellite"
    assert logbook_entry[logbook.LOGBOOK_ENTRY_DOMAIN] == assist_pipeline.DOMAIN
    assert (
        logbook_entry[logbook.LOGBOOK_ENTRY_MESSAGE]
        == "My Satellite started recording audio"
    )
|
|
@ -1,16 +1,23 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
import asyncio
|
||||
import base64
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
DeviceAudioQueue,
|
||||
Pipeline,
|
||||
PipelineData,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
|
||||
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
|
@ -2104,3 +2111,344 @@ async def test_wake_word_cooldown_different_entities(
|
|||
|
||||
# Wake words should be the same
|
||||
assert ww_id_1 == ww_id_2
|
||||
|
||||
|
||||
async def test_device_capture(
    hass: HomeAssistant,
    init_components,
    hass_ws_client: WebSocketGenerator,
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that pipeline audio of a satellite reaches a capturing client."""
    config_entry = MockConfigEntry()
    config_entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=config_entry.entry_id,
        connections=set(),
        identifiers={("demo", "satellite-1234")},
    )

    expected_audio = [b"chunk1", b"chunk2", b"chunk3"]

    # Begin capturing before the pipeline runs
    capture_ws = await hass_ws_client(hass)
    await capture_ws.send_json_auto_id(
        {
            "type": "assist_pipeline/device/capture",
            "timeout": 30,
            "device_id": satellite.id,
        }
    )
    capture_result = await capture_ws.receive_json()
    assert capture_result["success"]

    # Run an STT-only pipeline for the same device
    pipeline_ws = await hass_ws_client(hass)
    await pipeline_ws.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "stt",
            "input": {
                "sample_rate": 16000,
                "no_vad": True,
                "no_chunking": True,
            },
            "device_id": satellite.id,
        }
    )
    run_result = await pipeline_ws.receive_json()
    assert run_result["success"]

    # run start
    run_start = await pipeline_ws.receive_json()
    assert run_start["event"]["type"] == "run-start"
    run_start["event"]["data"]["pipeline"] = ANY
    assert run_start["event"]["data"] == snapshot
    handler_id = run_start["event"]["data"]["runner_data"]["stt_binary_handler_id"]
    handler_prefix = bytes([handler_id])

    # stt
    stt_start = await pipeline_ws.receive_json()
    assert stt_start["event"]["type"] == "stt-start"
    assert stt_start["event"]["data"] == snapshot

    for chunk in expected_audio:
        await pipeline_ws.send_bytes(handler_prefix + chunk)

    # A message holding only the handler id ends the audio stream
    await pipeline_ws.send_bytes(handler_prefix)

    stt_end = await pipeline_ws.receive_json()
    assert stt_end["event"]["type"] == "stt-end"

    # run end
    run_end = await pipeline_ws.receive_json()
    assert run_end["event"]["type"] == "run-end"
    assert run_end["event"]["data"] == snapshot

    # Collect everything the capture client received, up to the end event
    captured = []
    async with asyncio.timeout(1):
        while not captured or captured[-1]["type"] != "end":
            capture_msg = await capture_ws.receive_json()
            assert capture_msg["type"] == "event"
            captured.append(capture_msg["event"])

    # One event per audio chunk, followed by the end event
    assert len(captured) == len(expected_audio) + 1

    for event, chunk in zip(captured, expected_audio):
        assert event["type"] == "audio"
        assert event["rate"] == 16000
        assert event["width"] == 2
        assert event["channels"] == 1

        # Audio is base64 encoded
        assert event["audio"] == base64.b64encode(chunk).decode("ascii")

    assert captured[-1]["type"] == "end"
|
||||
|
||||
|
||||
async def test_device_capture_override(
    hass: HomeAssistant,
    init_components,
    hass_ws_client: WebSocketGenerator,
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that a second capture of the same device replaces the first."""
    config_entry = MockConfigEntry()
    config_entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=config_entry.entry_id,
        connections=set(),
        identifiers={("demo", "satellite-1234")},
    )

    first_chunk, *remaining_chunks = [b"chunk1", b"chunk2", b"chunk3"]

    def capture_command() -> dict:
        # A fresh dict per request; the request id is filled in when sending
        return {
            "type": "assist_pipeline/device/capture",
            "timeout": 30,
            "device_id": satellite.id,
        }

    # Start first capture
    first_capture_ws = await hass_ws_client(hass)
    await first_capture_ws.send_json_auto_id(capture_command())
    first_capture_result = await first_capture_ws.receive_json()
    assert first_capture_result["success"]

    # Run an STT-only pipeline for the same device
    pipeline_ws = await hass_ws_client(hass)
    await pipeline_ws.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "stt",
            "input": {
                "sample_rate": 16000,
                "no_vad": True,
                "no_chunking": True,
            },
            "device_id": satellite.id,
        }
    )
    run_result = await pipeline_ws.receive_json()
    assert run_result["success"]

    # run start
    run_start = await pipeline_ws.receive_json()
    assert run_start["event"]["type"] == "run-start"
    run_start["event"]["data"]["pipeline"] = ANY
    assert run_start["event"]["data"] == snapshot
    handler_id = run_start["event"]["data"]["runner_data"]["stt_binary_handler_id"]
    handler_prefix = bytes([handler_id])

    # stt
    stt_start = await pipeline_ws.receive_json()
    assert stt_start["event"]["type"] == "stt-start"
    assert stt_start["event"]["data"] == snapshot

    # Only the first chunk is sent while the first capture is active
    await pipeline_ws.send_bytes(handler_prefix + first_chunk)

    first_audio = await first_capture_ws.receive_json()
    assert first_audio["type"] == "event"
    assert first_audio["event"] == snapshot
    assert first_audio["event"]["audio"] == base64.b64encode(first_chunk).decode(
        "ascii"
    )

    # Start a new capture for the same device
    second_capture_ws = await hass_ws_client(hass)
    await second_capture_ws.send_json_auto_id(capture_command())
    second_capture_result = await second_capture_ws.receive_json()
    assert second_capture_result["success"]

    # Send remaining audio chunks
    for chunk in remaining_chunks:
        await pipeline_ws.send_bytes(handler_prefix + chunk)

    # A message holding only the handler id ends the audio stream
    await pipeline_ws.send_bytes(handler_prefix)

    stt_end = await pipeline_ws.receive_json()
    assert stt_end["event"]["type"] == "stt-end"
    assert stt_end["event"]["data"] == snapshot

    # run end
    run_end = await pipeline_ws.receive_json()
    assert run_end["event"]["type"] == "run-end"
    assert run_end["event"]["data"] == snapshot

    # The first capture must have ended without receiving more audio
    first_end = await first_capture_ws.receive_json()
    assert first_end["type"] == "event"
    assert first_end["event"] == snapshot
    assert first_end["event"]["type"] == "end"

    # The second capture must have received the remaining audio
    captured = []
    async with asyncio.timeout(1):
        while not captured or captured[-1]["type"] != "end":
            capture_msg = await second_capture_ws.receive_json()
            assert capture_msg["type"] == "event"
            captured.append(capture_msg["event"])

    # The remaining chunks plus the end event: one chunk fewer than were sent
    # in total (it went to the first capture), plus one for the end event.
    assert len(captured) == len(remaining_chunks) + 1

    for event, chunk in zip(captured, remaining_chunks):
        assert event["type"] == "audio"
        assert event["rate"] == 16000
        assert event["width"] == 2
        assert event["channels"] == 1

        # Audio is base64 encoded
        assert event["audio"] == base64.b64encode(chunk).decode("ascii")

    assert captured[-1]["type"] == "end"
|
||||
|
||||
|
||||
async def test_device_capture_queue_full(
    hass: HomeAssistant,
    init_components,
    hass_ws_client: WebSocketGenerator,
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that a full capture queue is reported as an overflow."""
    config_entry = MockConfigEntry()
    config_entry.add_to_hass(hass)
    satellite = device_registry.async_get_or_create(
        config_entry_id=config_entry.entry_id,
        connections=set(),
        identifiers={("demo", "satellite-1234")},
    )

    class FakeQueue(asyncio.Queue):
        """Queue that only has room for the None stop signal."""

        def put_nowait(self, item):
            if item is not None:
                raise asyncio.QueueFull()

            super().put_nowait(item)

    # Make the capture command use a queue that rejects every audio chunk
    with patch(
        "homeassistant.components.assist_pipeline.websocket_api.DeviceAudioQueue"
    ) as audio_queue_factory:
        audio_queue_factory.return_value = DeviceAudioQueue(queue=FakeQueue())

        # Start capture
        capture_ws = await hass_ws_client(hass)
        await capture_ws.send_json_auto_id(
            {
                "type": "assist_pipeline/device/capture",
                "timeout": 30,
                "device_id": satellite.id,
            }
        )
        capture_result = await capture_ws.receive_json()
        assert capture_result["success"]

    # Run an STT-only pipeline for the same device
    pipeline_ws = await hass_ws_client(hass)
    await pipeline_ws.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "stt",
            "input": {
                "sample_rate": 16000,
                "no_vad": True,
                "no_chunking": True,
            },
            "device_id": satellite.id,
        }
    )
    run_result = await pipeline_ws.receive_json()
    assert run_result["success"]

    # run start
    run_start = await pipeline_ws.receive_json()
    assert run_start["event"]["type"] == "run-start"
    run_start["event"]["data"]["pipeline"] = ANY
    assert run_start["event"]["data"] == snapshot
    handler_id = run_start["event"]["data"]["runner_data"]["stt_binary_handler_id"]

    # stt
    stt_start = await pipeline_ws.receive_json()
    assert stt_start["event"]["type"] == "stt-start"
    assert stt_start["event"]["data"] == snapshot

    # Single sample will "overflow" the queue
    await pipeline_ws.send_bytes(bytes([handler_id, 0, 0]))

    # A message holding only the handler id ends the audio stream
    await pipeline_ws.send_bytes(bytes([handler_id]))

    stt_end = await pipeline_ws.receive_json()
    assert stt_end["event"]["type"] == "stt-end"
    assert stt_end["event"]["data"] == snapshot

    run_end = await pipeline_ws.receive_json()
    assert run_end["event"]["type"] == "run-end"
    assert run_end["event"]["data"] == snapshot

    # The only capture event is the end event, flagged as overflowed
    async with asyncio.timeout(1):
        capture_end = await capture_ws.receive_json()
        assert capture_end["type"] == "event"
        assert capture_end["event"] == snapshot
        assert capture_end["event"]["type"] == "end"
        assert capture_end["event"]["overflow"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue