Add websocket command to capture audio from a device (#103936)

* Add websocket command to capture audio from a device

* Update homeassistant/components/assist_pipeline/pipeline.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Add device capture test

* More tests

* Add logbook

* Remove unnecessary check

* Remove seconds and make logbook message past tense

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Michael Hansen 2023-11-16 10:28:06 -06:00 committed by GitHub
parent 4536fb3541
commit b3e247d5f0
8 changed files with 720 additions and 16 deletions
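
For orientation, the new assist_pipeline/device/capture command accepts a device_id and a timeout (up to 60 seconds), then streams the captured audio back as "audio" events (base64-encoded 16 kHz, 16-bit mono PCM) followed by a single "end" event carrying an overflow flag. A minimal client-side sketch of that message flow, assuming an already-connected websocket client object ws with send_json/receive_json helpers (the helper name and ws object are illustrative, not part of this commit):

```python
import base64


async def capture_device_audio(ws, device_id: str, timeout: float = 30.0) -> bytes:
    """Collect raw PCM (16 kHz, 16-bit, mono) captured from a satellite device."""
    await ws.send_json(
        {
            "id": 1,
            "type": "assist_pipeline/device/capture",
            "device_id": device_id,
            "timeout": timeout,  # must be > 0 and <= 60 seconds
        }
    )
    result = await ws.receive_json()
    assert result["success"]

    audio = bytearray()
    while True:
        event = (await ws.receive_json())["event"]
        if event["type"] == "end":
            if event["overflow"]:
                print("Some audio was dropped because the capture queue filled up")
            break
        # event["type"] == "audio"; rate/width/channels describe the PCM format
        audio.extend(base64.b64decode(event["audio"]))
    return bytes(audio)
```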

View file

@@ -9,7 +9,13 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
)
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
@@ -40,6 +46,7 @@ __all__ = (
"PipelineEventType",
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
)
CONFIG_SCHEMA = vol.Schema(

View file

@@ -11,3 +11,5 @@ CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
EVENT_RECORDING = f"{DOMAIN}_recording"

View file

@@ -0,0 +1,39 @@
"""Describe assist_pipeline logbook events."""
from __future__ import annotations
from collections.abc import Callable
from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME
from homeassistant.const import ATTR_DEVICE_ID
from homeassistant.core import Event, HomeAssistant, callback
import homeassistant.helpers.device_registry as dr
from .const import DOMAIN, EVENT_RECORDING
@callback
def async_describe_events(
hass: HomeAssistant,
async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None],
) -> None:
"""Describe logbook events."""
device_registry = dr.async_get(hass)
@callback
def async_describe_logbook_event(event: Event) -> dict[str, str]:
"""Describe logbook event."""
device: dr.DeviceEntry | None = None
device_name: str = "Unknown device"
device = device_registry.devices[event.data[ATTR_DEVICE_ID]]
if device:
device_name = device.name_by_user or device.name or "Unknown device"
message = f"{device_name} started recording audio"
return {
LOGBOOK_ENTRY_NAME: device_name,
LOGBOOK_ENTRY_MESSAGE: message,
}
async_describe_event(DOMAIN, EVENT_RECORDING, async_describe_logbook_event)
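
To see how this describer is exercised: the capture websocket handler below fires EVENT_RECORDING with the device id, and the logbook renders it with the device's friendly name. A rough sketch of the round trip, assuming a running hass instance and a device registered as "Kitchen Satellite" (the names are illustrative only):

```python
# Fire the recording event, as the capture websocket handler does.
hass.bus.async_fire(
    "assist_pipeline_recording",          # EVENT_RECORDING
    {"device_id": satellite_device.id},   # ATTR_DEVICE_ID
)

# async_describe_logbook_event() above then produces an entry equivalent to:
# {LOGBOOK_ENTRY_NAME: "Kitchen Satellite",
#  LOGBOOK_ENTRY_MESSAGE: "Kitchen Satellite started recording audio"}
```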

View file

@@ -503,6 +503,9 @@ class PipelineRun:
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
"""Buffer used when splitting audio into chunks for audio processing"""
_device_id: str | None = None
"""Optional device id set during run start."""
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
@@ -554,7 +557,8 @@ class PipelineRun:
def start(self, device_id: str | None) -> None:
"""Emit run start event."""
self._start_debug_recording_thread(device_id)
self._device_id = device_id
self._start_debug_recording_thread()
data = {
"pipeline": self.pipeline.id,
@@ -567,6 +571,9 @@ class PipelineRun:
async def end(self) -> None:
"""Emit run end event."""
# Signal end of stream to listeners
self._capture_chunk(None)
# Stop the recording thread before emitting run-end.
# This ensures that files are properly closed if the event handler reads them.
await self._stop_debug_recording_thread()
@@ -746,9 +753,7 @@ class PipelineRun:
if self.abort_wake_word_detection:
raise WakeWordDetectionAborted
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk.audio)
self._capture_chunk(chunk.audio)
yield chunk.audio, chunk.timestamp_ms
# Wake-word-detection occurs *after* the wake word was actually
@@ -870,8 +875,7 @@ class PipelineRun:
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
sent_vad_start = False
async for chunk in audio_stream:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk.audio)
self._capture_chunk(chunk.audio)
if stt_vad is not None:
if not stt_vad.process(chunk_seconds, chunk.is_speech):
@@ -1057,7 +1061,28 @@ class PipelineRun:
return tts_media.url
def _start_debug_recording_thread(self, device_id: str | None) -> None:
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
"""Forward audio chunk to various capturing mechanisms."""
if self.debug_recording_queue is not None:
# Forward to debug WAV file recording
self.debug_recording_queue.put_nowait(audio_bytes)
if self._device_id is None:
return
# Forward to device audio capture
pipeline_data: PipelineData = self.hass.data[DOMAIN]
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
if audio_queue is None:
return
try:
audio_queue.queue.put_nowait(audio_bytes)
except asyncio.QueueFull:
audio_queue.overflow = True
_LOGGER.warning("Audio queue full for device %s", self._device_id)
def _start_debug_recording_thread(self) -> None:
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
if self.debug_recording_thread is not None:
# Already started
@@ -1068,7 +1093,7 @@ class PipelineRun:
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
CONF_DEBUG_RECORDING_DIR
):
if device_id is None:
if self._device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id>
run_recording_dir = (
Path(debug_recording_dir)
@@ -1079,7 +1104,7 @@ class PipelineRun:
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
run_recording_dir = (
Path(debug_recording_dir)
/ device_id
/ self._device_id
/ self.pipeline.name
/ str(time.monotonic_ns())
)
@@ -1100,8 +1125,8 @@ class PipelineRun:
# Not running
return
# Signal thread to stop gracefully
self.debug_recording_queue.put(None)
# NOTE: Expecting a None to have been put in self.debug_recording_queue
# in self.end() to signal the thread to stop.
# Wait until the thread has finished to ensure that files are fully written
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
@@ -1632,6 +1657,20 @@ class PipelineRuns:
pipeline_run.abort_wake_word_detection = True
@dataclass
class DeviceAudioQueue:
"""Audio capture queue for a satellite device."""
queue: asyncio.Queue[bytes | None]
"""Queue of audio chunks (None = stop signal)"""
id: str = field(default_factory=ulid_util.ulid)
"""Unique id to ensure the correct audio queue is cleaned up in websocket API."""
overflow: bool = False
"""Flag to be set if audio samples were dropped because the queue was full."""
class PipelineData:
"""Store and debug data stored in hass.data."""
@@ -1641,6 +1680,7 @@ class PipelineData:
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
self.pipeline_devices: set[str] = set()
self.pipeline_runs = PipelineRuns(pipeline_store)
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}
@dataclass
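
The capture path added to pipeline.py above boils down to one pattern: a bounded asyncio.Queue per device, put_nowait on the producer side so the pipeline never blocks, an overflow flag instead of a raised exception, and a None sentinel to mark end of stream. A standalone sketch of that pattern (the names are illustrative, not taken from the commit):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class CaptureQueue:
    """Illustrative stand-in for DeviceAudioQueue."""

    queue: asyncio.Queue[bytes | None] = field(
        default_factory=lambda: asyncio.Queue(maxsize=4)
    )
    overflow: bool = False


def forward_chunk(capture: CaptureQueue, chunk: bytes | None) -> None:
    """Producer side: never block the pipeline, just record the overflow."""
    try:
        capture.queue.put_nowait(chunk)
    except asyncio.QueueFull:
        capture.overflow = True  # the consumer learns about dropped audio at the end


async def drain(capture: CaptureQueue) -> list[bytes]:
    """Consumer side: read until the None sentinel signals end of stream."""
    chunks: list[bytes] = []
    while (chunk := await capture.queue.get()) is not None:
        chunks.append(chunk)
    return chunks


async def main() -> None:
    capture = CaptureQueue()
    for data in (b"a", b"b", b"c"):
        forward_chunk(capture, data)
    forward_chunk(capture, None)  # end-of-stream sentinel, like PipelineRun.end()
    print(await drain(capture), capture.overflow)


asyncio.run(main())
```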

View file

@@ -3,22 +3,31 @@ import asyncio
# Suppressing disable=deprecated-module is needed for Python 3.11
import audioop # pylint: disable=deprecated-module
import base64
from collections.abc import AsyncGenerator, Callable
import contextlib
import logging
from typing import Any
import math
from typing import Any, Final
import voluptuous as vol
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import MATCH_ALL
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
from .const import (
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
EVENT_RECORDING,
)
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
DeviceAudioQueue,
PipelineData,
PipelineError,
PipelineEvent,
@@ -32,6 +41,11 @@ from .pipeline import (
_LOGGER = logging.getLogger(__name__)
CAPTURE_RATE: Final = 16000
CAPTURE_WIDTH: Final = 2
CAPTURE_CHANNELS: Final = 1
MAX_CAPTURE_TIMEOUT: Final = 60.0
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
@@ -40,6 +54,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_list_languages)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run)
websocket_api.async_register_command(hass, websocket_device_capture)
@websocket_api.websocket_command(
@@ -371,3 +386,101 @@ async def websocket_list_languages(
else pipeline_languages
},
)
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/device/capture",
vol.Required("device_id"): str,
vol.Required("timeout"): vol.All(
# 0 < timeout <= MAX_CAPTURE_TIMEOUT
vol.Coerce(float),
vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT),
),
}
)
@websocket_api.async_response
async def websocket_device_capture(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Capture raw audio from a satellite device and forward to client."""
pipeline_data: PipelineData = hass.data[DOMAIN]
device_id = msg["device_id"]
# Number of seconds to record audio in wall clock time
timeout_seconds = msg["timeout"]
# We don't know the chunk size, so the upper bound is calculated assuming a
# single sample (16 bits) per queue item.
max_queue_items = (
# +1 for None to signal end
int(math.ceil(timeout_seconds * CAPTURE_RATE))
+ 1
)
audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items))
# Running simultaneous captures for a single device will not work by design.
# The new capture will cause the old capture to stop.
if (
old_audio_queue := pipeline_data.device_audio_queues.pop(device_id, None)
) is not None:
with contextlib.suppress(asyncio.QueueFull):
# Signal other websocket command that we're taking over
old_audio_queue.queue.put_nowait(None)
# Only one client can be capturing audio at a time
pipeline_data.device_audio_queues[device_id] = audio_queue
def clean_up_queue() -> None:
# Clean up our audio queue
maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id)
if (maybe_audio_queue is not None) and (maybe_audio_queue.id == audio_queue.id):
# Only pop if this is our queue
pipeline_data.device_audio_queues.pop(device_id)
# Unsubscribe cleans up queue
connection.subscriptions[msg["id"]] = clean_up_queue
# Audio will follow as events
connection.send_result(msg["id"])
# Record to logbook
hass.bus.async_fire(
EVENT_RECORDING,
{
ATTR_DEVICE_ID: device_id,
ATTR_SECONDS: timeout_seconds,
},
)
try:
with contextlib.suppress(asyncio.TimeoutError):
async with asyncio.timeout(timeout_seconds):
while True:
# Send audio chunks encoded as base64
audio_bytes = await audio_queue.queue.get()
if audio_bytes is None:
# Signal to stop
break
connection.send_event(
msg["id"],
{
"type": "audio",
"rate": CAPTURE_RATE, # hertz
"width": CAPTURE_WIDTH, # bytes
"channels": CAPTURE_CHANNELS,
"audio": base64.b64encode(audio_bytes).decode("ascii"),
},
)
# Capture has ended
connection.send_event(
msg["id"], {"type": "end", "overflow": audio_queue.overflow}
)
finally:
clean_up_queue()
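
Two notes on the handler above. First, the queue bound is deliberately conservative: with the 60 second maximum timeout and one 16-bit sample per queue item, the worst case is 60 * 16000 + 1 = 960,001 slots. Second, because each audio event spells out rate, width and channels, a client can write the decoded chunks straight into a WAV container; a small sketch using the standard-library wave module (write_capture_wav and the chunks argument are illustrative, not part of the commit):

```python
import wave


def write_capture_wav(path: str, chunks: list[bytes]) -> None:
    """Write base64-decoded capture chunks to a WAV file."""
    with wave.open(path, "wb") as wav_file:
        wav_file.setframerate(16000)  # CAPTURE_RATE (hertz)
        wav_file.setsampwidth(2)      # CAPTURE_WIDTH (bytes per sample)
        wav_file.setnchannels(1)      # CAPTURE_CHANNELS
        for chunk in chunks:
            wav_file.writeframes(chunk)
```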

View file

@@ -487,6 +487,119 @@
# name: test_audio_pipeline_with_wake_word_timeout.3
None
# ---
# name: test_device_capture
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_device_capture.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_device_capture.2
None
# ---
# name: test_device_capture_override
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_device_capture_override.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_device_capture_override.2
dict({
'audio': 'Y2h1bmsx',
'channels': 1,
'rate': 16000,
'type': 'audio',
'width': 2,
})
# ---
# name: test_device_capture_override.3
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_device_capture_override.4
None
# ---
# name: test_device_capture_override.5
dict({
'overflow': False,
'type': 'end',
})
# ---
# name: test_device_capture_queue_full
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_device_capture_queue_full.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_device_capture_queue_full.2
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_device_capture_queue_full.3
None
# ---
# name: test_device_capture_queue_full.4
dict({
'overflow': True,
'type': 'end',
})
# ---
# name: test_intent_failed
dict({
'language': 'en',

View file

@@ -0,0 +1,42 @@
"""The tests for assist_pipeline logbook."""
from homeassistant.components import assist_pipeline, logbook
from homeassistant.const import ATTR_DEVICE_ID
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.logbook.common import MockRow, mock_humanify
async def test_recording_event(
hass: HomeAssistant, init_components, device_registry: dr.DeviceRegistry
) -> None:
"""Test recording event."""
hass.config.components.add("recorder")
assert await async_setup_component(hass, "logbook", {})
entry = MockConfigEntry()
entry.add_to_hass(hass)
satellite_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "satellite-1234")},
)
device_registry.async_update_device(satellite_device.id, name="My Satellite")
event = mock_humanify(
hass,
[
MockRow(
assist_pipeline.EVENT_RECORDING,
{ATTR_DEVICE_ID: satellite_device.id},
),
],
)[0]
assert event[logbook.LOGBOOK_ENTRY_NAME] == "My Satellite"
assert event[logbook.LOGBOOK_ENTRY_DOMAIN] == assist_pipeline.DOMAIN
assert (
event[logbook.LOGBOOK_ENTRY_MESSAGE] == "My Satellite started recording audio"
)

View file

@@ -1,16 +1,23 @@
"""Websocket tests for Voice Assistant integration."""
import asyncio
import base64
from unittest.mock import ANY, patch
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.components.assist_pipeline.pipeline import (
DeviceAudioQueue,
Pipeline,
PipelineData,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@@ -2104,3 +2111,344 @@ async def test_wake_word_cooldown_different_entities(
# Wake words should be the same
assert ww_id_1 == ww_id_2
async def test_device_capture(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test audio capture from a satellite device."""
entry = MockConfigEntry()
entry.add_to_hass(hass)
satellite_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "satellite-1234")},
)
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
# Start capture
client_capture = await hass_ws_client(hass)
await client_capture.send_json_auto_id(
{
"type": "assist_pipeline/device/capture",
"timeout": 30,
"device_id": satellite_device.id,
}
)
# result
msg = await client_capture.receive_json()
assert msg["success"]
# Run pipeline
client_pipeline = await hass_ws_client(hass)
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"device_id": satellite_device.id,
}
)
# result
msg = await client_pipeline.receive_json()
assert msg["success"]
# run start
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
# stt
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
for audio_chunk in audio_chunks:
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk)
# End of audio stream
await client_pipeline.send_bytes(bytes([handler_id]))
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-end"
# run end
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
# Verify capture
events = []
async with asyncio.timeout(1):
while True:
msg = await client_capture.receive_json()
assert msg["type"] == "event"
event_data = msg["event"]
events.append(event_data)
if event_data["type"] == "end":
break
assert len(events) == len(audio_chunks) + 1
# Verify audio chunks
for i, audio_chunk in enumerate(audio_chunks):
assert events[i]["type"] == "audio"
assert events[i]["rate"] == 16000
assert events[i]["width"] == 2
assert events[i]["channels"] == 1
# Audio is base64 encoded
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
# Last event is the end
assert events[-1]["type"] == "end"
async def test_device_capture_override(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test overriding an existing audio capture from a satellite device."""
entry = MockConfigEntry()
entry.add_to_hass(hass)
satellite_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "satellite-1234")},
)
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
# Start first capture
client_capture_1 = await hass_ws_client(hass)
await client_capture_1.send_json_auto_id(
{
"type": "assist_pipeline/device/capture",
"timeout": 30,
"device_id": satellite_device.id,
}
)
# result
msg = await client_capture_1.receive_json()
assert msg["success"]
# Run pipeline
client_pipeline = await hass_ws_client(hass)
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"device_id": satellite_device.id,
}
)
# result
msg = await client_pipeline.receive_json()
assert msg["success"]
# run start
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
# stt
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
# Send first audio chunk
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunks[0])
# Verify first capture
msg = await client_capture_1.receive_json()
assert msg["type"] == "event"
assert msg["event"] == snapshot
assert msg["event"]["audio"] == base64.b64encode(audio_chunks[0]).decode("ascii")
# Start a new capture
client_capture_2 = await hass_ws_client(hass)
await client_capture_2.send_json_auto_id(
{
"type": "assist_pipeline/device/capture",
"timeout": 30,
"device_id": satellite_device.id,
}
)
# result (capture 2)
msg = await client_capture_2.receive_json()
assert msg["success"]
# Send remaining audio chunks
for audio_chunk in audio_chunks[1:]:
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk)
# End of audio stream
await client_pipeline.send_bytes(bytes([handler_id]))
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
# run end
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
# Verify that first capture ended with no more audio
msg = await client_capture_1.receive_json()
assert msg["type"] == "event"
assert msg["event"] == snapshot
assert msg["event"]["type"] == "end"
# Verify that the second capture got the remaining audio
events = []
async with asyncio.timeout(1):
while True:
msg = await client_capture_2.receive_json()
assert msg["type"] == "event"
event_data = msg["event"]
events.append(event_data)
if event_data["type"] == "end":
break
# -1 since first audio chunk went to the first capture
assert len(events) == len(audio_chunks)
# Verify all but first audio chunk
for i, audio_chunk in enumerate(audio_chunks[1:]):
assert events[i]["type"] == "audio"
assert events[i]["rate"] == 16000
assert events[i]["width"] == 2
assert events[i]["channels"] == 1
# Audio is base64 encoded
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
# Last event is the end
assert events[-1]["type"] == "end"
async def test_device_capture_queue_full(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test audio capture from a satellite device when the recording queue fills up."""
entry = MockConfigEntry()
entry.add_to_hass(hass)
satellite_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "satellite-1234")},
)
class FakeQueue(asyncio.Queue):
"""Queue that reports full for anything but None."""
def put_nowait(self, item):
if item is not None:
raise asyncio.QueueFull()
super().put_nowait(item)
with patch(
"homeassistant.components.assist_pipeline.websocket_api.DeviceAudioQueue"
) as mock:
mock.return_value = DeviceAudioQueue(queue=FakeQueue())
# Start capture
client_capture = await hass_ws_client(hass)
await client_capture.send_json_auto_id(
{
"type": "assist_pipeline/device/capture",
"timeout": 30,
"device_id": satellite_device.id,
}
)
# result
msg = await client_capture.receive_json()
assert msg["success"]
# Run pipeline
client_pipeline = await hass_ws_client(hass)
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"device_id": satellite_device.id,
}
)
# result
msg = await client_pipeline.receive_json()
assert msg["success"]
# run start
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
# stt
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
# Single sample will "overflow" the queue
await client_pipeline.send_bytes(bytes([handler_id, 0, 0]))
# End of audio stream
await client_pipeline.send_bytes(bytes([handler_id]))
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
msg = await client_pipeline.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
# Queue should have been overflowed
async with asyncio.timeout(1):
msg = await client_capture.receive_json()
assert msg["type"] == "event"
assert msg["event"] == snapshot
assert msg["event"]["type"] == "end"
assert msg["event"]["overflow"]