Add websocket command to capture audio from a device (#103936)
* Add websocket command to capture audio from a device * Update homeassistant/components/assist_pipeline/pipeline.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Add device capture test * More tests * Add logbook * Remove unnecessary check * Remove seconds and make logbook message past tense --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
4536fb3541
commit
b3e247d5f0
8 changed files with 720 additions and 16 deletions
|
@ -9,7 +9,13 @@ from homeassistant.components import stt
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
|
from .const import (
|
||||||
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
|
DATA_CONFIG,
|
||||||
|
DATA_LAST_WAKE_UP,
|
||||||
|
DOMAIN,
|
||||||
|
EVENT_RECORDING,
|
||||||
|
)
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
AudioSettings,
|
AudioSettings,
|
||||||
|
@ -40,6 +46,7 @@ __all__ = (
|
||||||
"PipelineEventType",
|
"PipelineEventType",
|
||||||
"PipelineNotFound",
|
"PipelineNotFound",
|
||||||
"WakeWordSettings",
|
"WakeWordSettings",
|
||||||
|
"EVENT_RECORDING",
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema(
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
|
|
|
@ -11,3 +11,5 @@ CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
||||||
|
|
||||||
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
||||||
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
||||||
|
|
||||||
|
EVENT_RECORDING = f"{DOMAIN}_recording"
|
||||||
|
|
39
homeassistant/components/assist_pipeline/logbook.py
Normal file
39
homeassistant/components/assist_pipeline/logbook.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
"""Describe assist_pipeline logbook events."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME
|
||||||
|
from homeassistant.const import ATTR_DEVICE_ID
|
||||||
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
|
import homeassistant.helpers.device_registry as dr
|
||||||
|
|
||||||
|
from .const import DOMAIN, EVENT_RECORDING
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_describe_events(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None],
|
||||||
|
) -> None:
|
||||||
|
"""Describe logbook events."""
|
||||||
|
device_registry = dr.async_get(hass)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_describe_logbook_event(event: Event) -> dict[str, str]:
|
||||||
|
"""Describe logbook event."""
|
||||||
|
device: dr.DeviceEntry | None = None
|
||||||
|
device_name: str = "Unknown device"
|
||||||
|
|
||||||
|
device = device_registry.devices[event.data[ATTR_DEVICE_ID]]
|
||||||
|
if device:
|
||||||
|
device_name = device.name_by_user or device.name or "Unknown device"
|
||||||
|
|
||||||
|
message = f"{device_name} started recording audio"
|
||||||
|
|
||||||
|
return {
|
||||||
|
LOGBOOK_ENTRY_NAME: device_name,
|
||||||
|
LOGBOOK_ENTRY_MESSAGE: message,
|
||||||
|
}
|
||||||
|
|
||||||
|
async_describe_event(DOMAIN, EVENT_RECORDING, async_describe_logbook_event)
|
|
@ -503,6 +503,9 @@ class PipelineRun:
|
||||||
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
|
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
|
||||||
"""Buffer used when splitting audio into chunks for audio processing"""
|
"""Buffer used when splitting audio into chunks for audio processing"""
|
||||||
|
|
||||||
|
_device_id: str | None = None
|
||||||
|
"""Optional device id set during run start."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
@ -554,7 +557,8 @@ class PipelineRun:
|
||||||
|
|
||||||
def start(self, device_id: str | None) -> None:
|
def start(self, device_id: str | None) -> None:
|
||||||
"""Emit run start event."""
|
"""Emit run start event."""
|
||||||
self._start_debug_recording_thread(device_id)
|
self._device_id = device_id
|
||||||
|
self._start_debug_recording_thread()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"pipeline": self.pipeline.id,
|
"pipeline": self.pipeline.id,
|
||||||
|
@ -567,6 +571,9 @@ class PipelineRun:
|
||||||
|
|
||||||
async def end(self) -> None:
|
async def end(self) -> None:
|
||||||
"""Emit run end event."""
|
"""Emit run end event."""
|
||||||
|
# Signal end of stream to listeners
|
||||||
|
self._capture_chunk(None)
|
||||||
|
|
||||||
# Stop the recording thread before emitting run-end.
|
# Stop the recording thread before emitting run-end.
|
||||||
# This ensures that files are properly closed if the event handler reads them.
|
# This ensures that files are properly closed if the event handler reads them.
|
||||||
await self._stop_debug_recording_thread()
|
await self._stop_debug_recording_thread()
|
||||||
|
@ -746,9 +753,7 @@ class PipelineRun:
|
||||||
if self.abort_wake_word_detection:
|
if self.abort_wake_word_detection:
|
||||||
raise WakeWordDetectionAborted
|
raise WakeWordDetectionAborted
|
||||||
|
|
||||||
if self.debug_recording_queue is not None:
|
self._capture_chunk(chunk.audio)
|
||||||
self.debug_recording_queue.put_nowait(chunk.audio)
|
|
||||||
|
|
||||||
yield chunk.audio, chunk.timestamp_ms
|
yield chunk.audio, chunk.timestamp_ms
|
||||||
|
|
||||||
# Wake-word-detection occurs *after* the wake word was actually
|
# Wake-word-detection occurs *after* the wake word was actually
|
||||||
|
@ -870,8 +875,7 @@ class PipelineRun:
|
||||||
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
||||||
sent_vad_start = False
|
sent_vad_start = False
|
||||||
async for chunk in audio_stream:
|
async for chunk in audio_stream:
|
||||||
if self.debug_recording_queue is not None:
|
self._capture_chunk(chunk.audio)
|
||||||
self.debug_recording_queue.put_nowait(chunk.audio)
|
|
||||||
|
|
||||||
if stt_vad is not None:
|
if stt_vad is not None:
|
||||||
if not stt_vad.process(chunk_seconds, chunk.is_speech):
|
if not stt_vad.process(chunk_seconds, chunk.is_speech):
|
||||||
|
@ -1057,7 +1061,28 @@ class PipelineRun:
|
||||||
|
|
||||||
return tts_media.url
|
return tts_media.url
|
||||||
|
|
||||||
def _start_debug_recording_thread(self, device_id: str | None) -> None:
|
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
||||||
|
"""Forward audio chunk to various capturing mechanisms."""
|
||||||
|
if self.debug_recording_queue is not None:
|
||||||
|
# Forward to debug WAV file recording
|
||||||
|
self.debug_recording_queue.put_nowait(audio_bytes)
|
||||||
|
|
||||||
|
if self._device_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Forward to device audio capture
|
||||||
|
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||||
|
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
|
||||||
|
if audio_queue is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio_queue.queue.put_nowait(audio_bytes)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
audio_queue.overflow = True
|
||||||
|
_LOGGER.warning("Audio queue full for device %s", self._device_id)
|
||||||
|
|
||||||
|
def _start_debug_recording_thread(self) -> None:
|
||||||
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
|
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
|
||||||
if self.debug_recording_thread is not None:
|
if self.debug_recording_thread is not None:
|
||||||
# Already started
|
# Already started
|
||||||
|
@ -1068,7 +1093,7 @@ class PipelineRun:
|
||||||
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
||||||
CONF_DEBUG_RECORDING_DIR
|
CONF_DEBUG_RECORDING_DIR
|
||||||
):
|
):
|
||||||
if device_id is None:
|
if self._device_id is None:
|
||||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||||
run_recording_dir = (
|
run_recording_dir = (
|
||||||
Path(debug_recording_dir)
|
Path(debug_recording_dir)
|
||||||
|
@ -1079,7 +1104,7 @@ class PipelineRun:
|
||||||
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
|
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
|
||||||
run_recording_dir = (
|
run_recording_dir = (
|
||||||
Path(debug_recording_dir)
|
Path(debug_recording_dir)
|
||||||
/ device_id
|
/ self._device_id
|
||||||
/ self.pipeline.name
|
/ self.pipeline.name
|
||||||
/ str(time.monotonic_ns())
|
/ str(time.monotonic_ns())
|
||||||
)
|
)
|
||||||
|
@ -1100,8 +1125,8 @@ class PipelineRun:
|
||||||
# Not running
|
# Not running
|
||||||
return
|
return
|
||||||
|
|
||||||
# Signal thread to stop gracefully
|
# NOTE: Expecting a None to have been put in self.debug_recording_queue
|
||||||
self.debug_recording_queue.put(None)
|
# in self.end() to signal the thread to stop.
|
||||||
|
|
||||||
# Wait until the thread has finished to ensure that files are fully written
|
# Wait until the thread has finished to ensure that files are fully written
|
||||||
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
|
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
|
||||||
|
@ -1632,6 +1657,20 @@ class PipelineRuns:
|
||||||
pipeline_run.abort_wake_word_detection = True
|
pipeline_run.abort_wake_word_detection = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceAudioQueue:
|
||||||
|
"""Audio capture queue for a satellite device."""
|
||||||
|
|
||||||
|
queue: asyncio.Queue[bytes | None]
|
||||||
|
"""Queue of audio chunks (None = stop signal)"""
|
||||||
|
|
||||||
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
|
"""Unique id to ensure the correct audio queue is cleaned up in websocket API."""
|
||||||
|
|
||||||
|
overflow: bool = False
|
||||||
|
"""Flag to be set if audio samples were dropped because the queue was full."""
|
||||||
|
|
||||||
|
|
||||||
class PipelineData:
|
class PipelineData:
|
||||||
"""Store and debug data stored in hass.data."""
|
"""Store and debug data stored in hass.data."""
|
||||||
|
|
||||||
|
@ -1641,6 +1680,7 @@ class PipelineData:
|
||||||
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
|
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
|
||||||
self.pipeline_devices: set[str] = set()
|
self.pipeline_devices: set[str] = set()
|
||||||
self.pipeline_runs = PipelineRuns(pipeline_store)
|
self.pipeline_runs = PipelineRuns(pipeline_store)
|
||||||
|
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -3,22 +3,31 @@ import asyncio
|
||||||
|
|
||||||
# Suppressing disable=deprecated-module is needed for Python 3.11
|
# Suppressing disable=deprecated-module is needed for Python 3.11
|
||||||
import audioop # pylint: disable=deprecated-module
|
import audioop # pylint: disable=deprecated-module
|
||||||
|
import base64
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
import math
|
||||||
|
from typing import Any, Final
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.util import language as language_util
|
from homeassistant.util import language as language_util
|
||||||
|
|
||||||
from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
|
from .const import (
|
||||||
|
DEFAULT_PIPELINE_TIMEOUT,
|
||||||
|
DEFAULT_WAKE_WORD_TIMEOUT,
|
||||||
|
DOMAIN,
|
||||||
|
EVENT_RECORDING,
|
||||||
|
)
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
AudioSettings,
|
AudioSettings,
|
||||||
|
DeviceAudioQueue,
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
|
@ -32,6 +41,11 @@ from .pipeline import (
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CAPTURE_RATE: Final = 16000
|
||||||
|
CAPTURE_WIDTH: Final = 2
|
||||||
|
CAPTURE_CHANNELS: Final = 1
|
||||||
|
MAX_CAPTURE_TIMEOUT: Final = 60.0
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
|
@ -40,6 +54,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
websocket_api.async_register_command(hass, websocket_list_languages)
|
websocket_api.async_register_command(hass, websocket_list_languages)
|
||||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||||
websocket_api.async_register_command(hass, websocket_get_run)
|
websocket_api.async_register_command(hass, websocket_get_run)
|
||||||
|
websocket_api.async_register_command(hass, websocket_device_capture)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
|
@ -371,3 +386,101 @@ async def websocket_list_languages(
|
||||||
else pipeline_languages
|
else pipeline_languages
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_pipeline/device/capture",
|
||||||
|
vol.Required("device_id"): str,
|
||||||
|
vol.Required("timeout"): vol.All(
|
||||||
|
# 0 < timeout <= MAX_CAPTURE_TIMEOUT
|
||||||
|
vol.Coerce(float),
|
||||||
|
vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_device_capture(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Capture raw audio from a satellite device and forward to client."""
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
device_id = msg["device_id"]
|
||||||
|
|
||||||
|
# Number of seconds to record audio in wall clock time
|
||||||
|
timeout_seconds = msg["timeout"]
|
||||||
|
|
||||||
|
# We don't know the chunk size, so the upper bound is calculated assuming a
|
||||||
|
# single sample (16 bits) per queue item.
|
||||||
|
max_queue_items = (
|
||||||
|
# +1 for None to signal end
|
||||||
|
int(math.ceil(timeout_seconds * CAPTURE_RATE))
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items))
|
||||||
|
|
||||||
|
# Running simultaneous captures for a single device will not work by design.
|
||||||
|
# The new capture will cause the old capture to stop.
|
||||||
|
if (
|
||||||
|
old_audio_queue := pipeline_data.device_audio_queues.pop(device_id, None)
|
||||||
|
) is not None:
|
||||||
|
with contextlib.suppress(asyncio.QueueFull):
|
||||||
|
# Signal other websocket command that we're taking over
|
||||||
|
old_audio_queue.queue.put_nowait(None)
|
||||||
|
|
||||||
|
# Only one client can be capturing audio at a time
|
||||||
|
pipeline_data.device_audio_queues[device_id] = audio_queue
|
||||||
|
|
||||||
|
def clean_up_queue() -> None:
|
||||||
|
# Clean up our audio queue
|
||||||
|
maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id)
|
||||||
|
if (maybe_audio_queue is not None) and (maybe_audio_queue.id == audio_queue.id):
|
||||||
|
# Only pop if this is our queue
|
||||||
|
pipeline_data.device_audio_queues.pop(device_id)
|
||||||
|
|
||||||
|
# Unsubscribe cleans up queue
|
||||||
|
connection.subscriptions[msg["id"]] = clean_up_queue
|
||||||
|
|
||||||
|
# Audio will follow as events
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
# Record to logbook
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_RECORDING,
|
||||||
|
{
|
||||||
|
ATTR_DEVICE_ID: device_id,
|
||||||
|
ATTR_SECONDS: timeout_seconds,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with contextlib.suppress(asyncio.TimeoutError):
|
||||||
|
async with asyncio.timeout(timeout_seconds):
|
||||||
|
while True:
|
||||||
|
# Send audio chunks encoded as base64
|
||||||
|
audio_bytes = await audio_queue.queue.get()
|
||||||
|
if audio_bytes is None:
|
||||||
|
# Signal to stop
|
||||||
|
break
|
||||||
|
|
||||||
|
connection.send_event(
|
||||||
|
msg["id"],
|
||||||
|
{
|
||||||
|
"type": "audio",
|
||||||
|
"rate": CAPTURE_RATE, # hertz
|
||||||
|
"width": CAPTURE_WIDTH, # bytes
|
||||||
|
"channels": CAPTURE_CHANNELS,
|
||||||
|
"audio": base64.b64encode(audio_bytes).decode("ascii"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture has ended
|
||||||
|
connection.send_event(
|
||||||
|
msg["id"], {"type": "end", "overflow": audio_queue.overflow}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clean_up_queue()
|
||||||
|
|
|
@ -487,6 +487,119 @@
|
||||||
# name: test_audio_pipeline_with_wake_word_timeout.3
|
# name: test_audio_pipeline_with_wake_word_timeout.3
|
||||||
None
|
None
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_device_capture
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture.2
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override.2
|
||||||
|
dict({
|
||||||
|
'audio': 'Y2h1bmsx',
|
||||||
|
'channels': 1,
|
||||||
|
'rate': 16000,
|
||||||
|
'type': 'audio',
|
||||||
|
'width': 2,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override.3
|
||||||
|
dict({
|
||||||
|
'stt_output': dict({
|
||||||
|
'text': 'test transcript',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override.4
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_override.5
|
||||||
|
dict({
|
||||||
|
'overflow': False,
|
||||||
|
'type': 'end',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_queue_full
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_queue_full.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_queue_full.2
|
||||||
|
dict({
|
||||||
|
'stt_output': dict({
|
||||||
|
'text': 'test transcript',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_queue_full.3
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_device_capture_queue_full.4
|
||||||
|
dict({
|
||||||
|
'overflow': True,
|
||||||
|
'type': 'end',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_intent_failed
|
# name: test_intent_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
|
|
42
tests/components/assist_pipeline/test_logbook.py
Normal file
42
tests/components/assist_pipeline/test_logbook.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
"""The tests for assist_pipeline logbook."""
|
||||||
|
from homeassistant.components import assist_pipeline, logbook
|
||||||
|
from homeassistant.const import ATTR_DEVICE_ID
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.components.logbook.common import MockRow, mock_humanify
|
||||||
|
|
||||||
|
|
||||||
|
async def test_recording_event(
|
||||||
|
hass: HomeAssistant, init_components, device_registry: dr.DeviceRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test recording event."""
|
||||||
|
hass.config.components.add("recorder")
|
||||||
|
assert await async_setup_component(hass, "logbook", {})
|
||||||
|
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
satellite_device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "satellite-1234")},
|
||||||
|
)
|
||||||
|
|
||||||
|
device_registry.async_update_device(satellite_device.id, name="My Satellite")
|
||||||
|
event = mock_humanify(
|
||||||
|
hass,
|
||||||
|
[
|
||||||
|
MockRow(
|
||||||
|
assist_pipeline.EVENT_RECORDING,
|
||||||
|
{ATTR_DEVICE_ID: satellite_device.id},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert event[logbook.LOGBOOK_ENTRY_NAME] == "My Satellite"
|
||||||
|
assert event[logbook.LOGBOOK_ENTRY_DOMAIN] == assist_pipeline.DOMAIN
|
||||||
|
assert (
|
||||||
|
event[logbook.LOGBOOK_ENTRY_MESSAGE] == "My Satellite started recording audio"
|
||||||
|
)
|
|
@ -1,16 +1,23 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
|
DeviceAudioQueue,
|
||||||
|
Pipeline,
|
||||||
|
PipelineData,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
|
||||||
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
|
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@ -2104,3 +2111,344 @@ async def test_wake_word_cooldown_different_entities(
|
||||||
|
|
||||||
# Wake words should be the same
|
# Wake words should be the same
|
||||||
assert ww_id_1 == ww_id_2
|
assert ww_id_1 == ww_id_2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_capture(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test audio capture from a satellite device."""
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
satellite_device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "satellite-1234")},
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
|
||||||
|
|
||||||
|
# Start capture
|
||||||
|
client_capture = await hass_ws_client(hass)
|
||||||
|
await client_capture.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/device/capture",
|
||||||
|
"timeout": 30,
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_capture.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
client_pipeline = await hass_ws_client(hass)
|
||||||
|
await client_pipeline.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "stt",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
for audio_chunk in audio_chunks:
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk)
|
||||||
|
|
||||||
|
# End of audio stream
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Verify capture
|
||||||
|
events = []
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
while True:
|
||||||
|
msg = await client_capture.receive_json()
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
event_data = msg["event"]
|
||||||
|
events.append(event_data)
|
||||||
|
if event_data["type"] == "end":
|
||||||
|
break
|
||||||
|
|
||||||
|
assert len(events) == len(audio_chunks) + 1
|
||||||
|
|
||||||
|
# Verify audio chunks
|
||||||
|
for i, audio_chunk in enumerate(audio_chunks):
|
||||||
|
assert events[i]["type"] == "audio"
|
||||||
|
assert events[i]["rate"] == 16000
|
||||||
|
assert events[i]["width"] == 2
|
||||||
|
assert events[i]["channels"] == 1
|
||||||
|
|
||||||
|
# Audio is base64 encoded
|
||||||
|
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
||||||
|
|
||||||
|
# Last event is the end
|
||||||
|
assert events[-1]["type"] == "end"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_capture_override(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test overriding an existing audio capture from a satellite device."""
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
satellite_device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "satellite-1234")},
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
|
||||||
|
|
||||||
|
# Start first capture
|
||||||
|
client_capture_1 = await hass_ws_client(hass)
|
||||||
|
await client_capture_1.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/device/capture",
|
||||||
|
"timeout": 30,
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_capture_1.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
client_pipeline = await hass_ws_client(hass)
|
||||||
|
await client_pipeline.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "stt",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Send first audio chunk
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunks[0])
|
||||||
|
|
||||||
|
# Verify first capture
|
||||||
|
msg = await client_capture_1.receive_json()
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == snapshot
|
||||||
|
assert msg["event"]["audio"] == base64.b64encode(audio_chunks[0]).decode("ascii")
|
||||||
|
|
||||||
|
# Start a new capture
|
||||||
|
client_capture_2 = await hass_ws_client(hass)
|
||||||
|
await client_capture_2.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/device/capture",
|
||||||
|
"timeout": 30,
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result (capture 2)
|
||||||
|
msg = await client_capture_2.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# Send remaining audio chunks
|
||||||
|
for audio_chunk in audio_chunks[1:]:
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk)
|
||||||
|
|
||||||
|
# End of audio stream
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Verify that first capture ended with no more audio
|
||||||
|
msg = await client_capture_1.receive_json()
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == snapshot
|
||||||
|
assert msg["event"]["type"] == "end"
|
||||||
|
|
||||||
|
# Verify that the second capture got the remaining audio
|
||||||
|
events = []
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
while True:
|
||||||
|
msg = await client_capture_2.receive_json()
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
event_data = msg["event"]
|
||||||
|
events.append(event_data)
|
||||||
|
if event_data["type"] == "end":
|
||||||
|
break
|
||||||
|
|
||||||
|
# -1 since first audio chunk went to the first capture
|
||||||
|
assert len(events) == len(audio_chunks)
|
||||||
|
|
||||||
|
# Verify all but first audio chunk
|
||||||
|
for i, audio_chunk in enumerate(audio_chunks[1:]):
|
||||||
|
assert events[i]["type"] == "audio"
|
||||||
|
assert events[i]["rate"] == 16000
|
||||||
|
assert events[i]["width"] == 2
|
||||||
|
assert events[i]["channels"] == 1
|
||||||
|
|
||||||
|
# Audio is base64 encoded
|
||||||
|
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
||||||
|
|
||||||
|
# Last event is the end
|
||||||
|
assert events[-1]["type"] == "end"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_capture_queue_full(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test audio capture from a satellite device when the recording queue fills up."""
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
satellite_device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "satellite-1234")},
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeQueue(asyncio.Queue):
|
||||||
|
"""Queue that reports full for anything but None."""
|
||||||
|
|
||||||
|
def put_nowait(self, item):
|
||||||
|
if item is not None:
|
||||||
|
raise asyncio.QueueFull()
|
||||||
|
|
||||||
|
super().put_nowait(item)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.websocket_api.DeviceAudioQueue"
|
||||||
|
) as mock:
|
||||||
|
mock.return_value = DeviceAudioQueue(queue=FakeQueue())
|
||||||
|
|
||||||
|
# Start capture
|
||||||
|
client_capture = await hass_ws_client(hass)
|
||||||
|
await client_capture.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/device/capture",
|
||||||
|
"timeout": 30,
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_capture.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
client_pipeline = await hass_ws_client(hass)
|
||||||
|
await client_pipeline.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "stt",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
"device_id": satellite_device.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Single sample will "overflow" the queue
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id, 0, 0]))
|
||||||
|
|
||||||
|
# End of audio stream
|
||||||
|
await client_pipeline.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Queue should have been overflowed
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
msg = await client_capture.receive_json()
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == snapshot
|
||||||
|
assert msg["event"]["type"] == "end"
|
||||||
|
assert msg["event"]["overflow"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue