From 244f6d800220f8aef2e2b4bf0eb2d2ae12a6393f Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 6 Oct 2023 02:18:35 -0500 Subject: [PATCH] Add wake word cooldown to avoid duplicate wake-ups (#101417) --- .../components/assist_pipeline/__init__.py | 6 +- .../components/assist_pipeline/const.py | 9 ++ .../components/assist_pipeline/pipeline.py | 29 ++++++- .../assist_pipeline/websocket_api.py | 7 +- .../assist_pipeline/snapshots/test_init.ambr | 2 + .../snapshots/test_websocket.ambr | 62 ++++++++++++++ tests/components/assist_pipeline/test_init.py | 16 ++-- .../assist_pipeline/test_websocket.py | 85 ++++++++++++++++++- 8 files changed, 198 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 9a61346f673..fab4c3178bc 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -9,7 +9,7 @@ from homeassistant.components import stt from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType -from .const import DATA_CONFIG, DOMAIN +from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -45,7 +45,9 @@ __all__ = ( CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( - {vol.Optional("debug_recording_dir"): str}, + { + vol.Optional(CONF_DEBUG_RECORDING_DIR): str, + }, ) }, extra=vol.ALLOW_EXTRA, diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index e21d9003a69..84b49fc18fa 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -2,3 +2,12 @@ DOMAIN = "assist_pipeline" DATA_CONFIG = f"{DOMAIN}.config" + +DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds + +DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds + +CONF_DEBUG_RECORDING_DIR = "debug_recording_dir" + +DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up" +DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 76444fb2436..6ec031baf3b 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -48,7 +48,13 @@ from homeassistant.util import ( ) from homeassistant.util.limited_size_dict import LimitedSizeDict -from .const import DATA_CONFIG, DOMAIN +from .const import ( + CONF_DEBUG_RECORDING_DIR, + DATA_CONFIG, + DATA_LAST_WAKE_UP, + DEFAULT_WAKE_WORD_COOLDOWN, + DOMAIN, +) from .error import ( IntentRecognitionError, PipelineError, @@ -399,6 +405,9 @@ class WakeWordSettings: audio_seconds_to_buffer: float = 0 """Seconds of audio to buffer before detection and forward to STT.""" + cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN + """Seconds after a wake word detection where other detections are ignored.""" + @dataclass(frozen=True) class AudioSettings: @@ -603,6 +612,8 @@ class PipelineRun: ) ) + wake_word_settings = self.wake_word_settings or WakeWordSettings() + # Remove language since it doesn't apply to wake words yet metadata_dict.pop("language", None) @@ -612,6 +623,7 @@ class PipelineRun: { "entity_id": self.wake_word_entity_id, "metadata": metadata_dict, + "timeout": wake_word_settings.timeout or 0, }, ) ) @@ -619,8 +631,6 @@ class PipelineRun: if self.debug_recording_queue is not None: self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}") - wake_word_settings = self.wake_word_settings or WakeWordSettings() - wake_word_vad: VoiceActivityTimeout | None = None if (wake_word_settings.timeout is not None) and ( wake_word_settings.timeout > 0 @@ -670,6 +680,17 @@ class PipelineRun: if result is None: wake_word_output: dict[str, Any] = {} else: + # Avoid duplicate detections by checking cooldown + last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP) + if last_wake_up is not None: + sec_since_last_wake_up = time.monotonic() - last_wake_up + if sec_since_last_wake_up < wake_word_settings.cooldown_seconds: + _LOGGER.debug("Duplicate wake word detection occurred") + raise WakeWordDetectionAborted + + # Record last wake up time to block duplicate detections + self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic() + if result.queued_audio: # Add audio that was pending at detection. # @@ -1032,7 +1053,7 @@ class PipelineRun: # Directory to save audio for each pipeline run. # Configured in YAML for assist_pipeline. if debug_recording_dir := self.hass.data[DATA_CONFIG].get( - "debug_recording_dir" + CONF_DEBUG_RECORDING_DIR ): if device_id is None: # // diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 798843ea6e3..fda3e266490 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.util import language as language_util -from .const import DOMAIN +from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -30,9 +30,6 @@ from .pipeline import ( async_get_pipeline, ) -DEFAULT_TIMEOUT = 60 * 5 # seconds -DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds - _LOGGER = logging.getLogger(__name__) @@ -117,7 +114,7 @@ async def websocket_run( ) return - timeout = msg.get("timeout", DEFAULT_TIMEOUT) + timeout = msg.get("timeout", DEFAULT_PIPELINE_TIMEOUT) start_stage = PipelineStage(msg["start_stage"]) end_stage = PipelineStage(msg["end_stage"]) handler_id: int | None = None diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 3f0582f2bfb..e822759d208 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -285,6 +285,7 @@ 'format': , 'sample_rate': , }), + 'timeout': 0, }), 'type': , }), @@ -396,6 +397,7 @@ 'format': , 'sample_rate': , }), + 'timeout': 0, }), 'type': , }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 7cecf9fed40..b8c668f3fd0 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -373,6 +373,7 @@ 'format': 'wav', 'sample_rate': 16000, }), + 'timeout': 0, }) # --- # name: test_audio_pipeline_with_wake_word_no_timeout.2 @@ -474,6 +475,7 @@ 'format': 'wav', 'sample_rate': 16000, }), + 'timeout': 1, }) # --- # name: test_audio_pipeline_with_wake_word_timeout.2 @@ -655,3 +657,63 @@ # name: test_tts_failed.2 None # --- +# name: test_wake_word_cooldown + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown.1 + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown.2 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown.3 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown.4 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww', + }), + }) +# --- +# name: test_wake_word_cooldown.5 + dict({ + 'code': 'wake_word_detection_aborted', + 'message': '', + }) +# --- diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 98ecae628f1..a98858a1bce 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -10,6 +10,10 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, stt +from homeassistant.components.assist_pipeline.const import ( + CONF_DEBUG_RECORDING_DIR, + DOMAIN, +) from homeassistant.core import Context, HomeAssistant from homeassistant.setup import async_setup_component @@ -395,8 +399,8 @@ async def test_pipeline_save_audio( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) pipeline = assist_pipeline.async_get_pipeline(hass) @@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): @@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index f995a0d3577..28b31e5b19c 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -9,6 +9,8 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from .conftest import MockWakeWordEntity + from tests.typing import WebSocketGenerator @@ -266,7 +268,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout( events.append(msg["event"]) # "audio" - await client.send_bytes(bytes([1]) + b"wake word") + await client.send_bytes(bytes([handler_id]) + b"wake word") msg = await client.receive_json() assert msg["event"]["type"] == "wake_word-end" @@ -1805,3 +1807,84 @@ async def test_audio_pipeline_with_enhancements( msg = await client.receive_json() assert msg["success"] assert msg["result"] == {"events": events} + + +async def test_wake_word_cooldown( + hass: HomeAssistant, + init_components, + mock_wake_word_provider_entity: MockWakeWordEntity, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test that duplicate wake word detections are blocked during the cooldown period.""" + client_1 = await hass_ws_client(hass) + client_2 = await hass_ws_client(hass) + + await client_1.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + } + ) + + await client_2.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + } + ) + + # result + msg = await client_1.receive_json() + assert msg["success"], msg + + msg = await client_2.receive_json() + assert msg["success"], msg + + # run start + msg = await client_1.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + # wake_word + msg = await client_1.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + # Wake both up at the same time + await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") + await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + + # Get response events + msg = await client_1.receive_json() + event_type_1 = msg["event"]["type"] + + msg = await client_2.receive_json() + event_type_2 = msg["event"]["type"] + + # One should be a wake up, one should be an error + assert {event_type_1, event_type_2} == {"wake_word-end", "error"}