Add wake word cooldown to avoid duplicate wake-ups (#101417)

This commit is contained in:
Michael Hansen 2023-10-06 02:18:35 -05:00 committed by GitHub
parent 48a23798d0
commit 244f6d8002
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 198 additions and 18 deletions

View file

@ -9,7 +9,7 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import DATA_CONFIG, DOMAIN
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
@ -45,7 +45,9 @@ __all__ = (
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{vol.Optional("debug_recording_dir"): str},
{
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
},
)
},
extra=vol.ALLOW_EXTRA,

View file

@ -2,3 +2,12 @@
DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds

View file

@ -48,7 +48,13 @@ from homeassistant.util import (
)
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DATA_CONFIG, DOMAIN
from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DEFAULT_WAKE_WORD_COOLDOWN,
DOMAIN,
)
from .error import (
IntentRecognitionError,
PipelineError,
@ -399,6 +405,9 @@ class WakeWordSettings:
audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT."""
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
"""Seconds after a wake word detection where other detections are ignored."""
@dataclass(frozen=True)
class AudioSettings:
@ -603,6 +612,8 @@ class PipelineRun:
)
)
wake_word_settings = self.wake_word_settings or WakeWordSettings()
# Remove language since it doesn't apply to wake words yet
metadata_dict.pop("language", None)
@ -612,6 +623,7 @@ class PipelineRun:
{
"entity_id": self.wake_word_entity_id,
"metadata": metadata_dict,
"timeout": wake_word_settings.timeout or 0,
},
)
)
@ -619,8 +631,6 @@ class PipelineRun:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
wake_word_settings = self.wake_word_settings or WakeWordSettings()
wake_word_vad: VoiceActivityTimeout | None = None
if (wake_word_settings.timeout is not None) and (
wake_word_settings.timeout > 0
@ -670,6 +680,17 @@ class PipelineRun:
if result is None:
wake_word_output: dict[str, Any] = {}
else:
# Avoid duplicate detections by checking cooldown
last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
_LOGGER.debug("Duplicate wake word detection occurred")
raise WakeWordDetectionAborted
# Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic()
if result.queued_audio:
# Add audio that was pending at detection.
#
@ -1032,7 +1053,7 @@ class PipelineRun:
# Directory to save audio for each pipeline run.
# Configured in YAML for assist_pipeline.
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
"debug_recording_dir"
CONF_DEBUG_RECORDING_DIR
):
if device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id>

View file

@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DOMAIN
from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
@ -30,9 +30,6 @@ from .pipeline import (
async_get_pipeline,
)
DEFAULT_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
_LOGGER = logging.getLogger(__name__)
@ -117,7 +114,7 @@ async def websocket_run(
)
return
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
timeout = msg.get("timeout", DEFAULT_PIPELINE_TIMEOUT)
start_stage = PipelineStage(msg["start_stage"])
end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None

View file

@ -285,6 +285,7 @@
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
@ -396,6 +397,7 @@
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),

View file

@ -373,6 +373,7 @@
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 0,
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2
@ -474,6 +475,7 @@
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 1,
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.2
@ -655,3 +657,63 @@
# name: test_tts_failed.2
None
# ---
# name: test_wake_word_cooldown
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown.2
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown.3
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown.4
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
}),
})
# ---
# name: test_wake_word_cooldown.5
dict({
'code': 'wake_word_detection_aborted',
'message': '',
})
# ---

View file

@ -10,6 +10,10 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
pipeline = assist_pipeline.async_get_pipeline(hass)
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):

View file

@ -9,6 +9,8 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from .conftest import MockWakeWordEntity
from tests.typing import WebSocketGenerator
@ -266,7 +268,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
events.append(msg["event"])
# "audio"
await client.send_bytes(bytes([1]) + b"wake word")
await client.send_bytes(bytes([handler_id]) + b"wake word")
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-end"
@ -1805,3 +1807,84 @@ async def test_audio_pipeline_with_enhancements(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_wake_word_cooldown(
hass: HomeAssistant,
init_components,
mock_wake_word_provider_entity: MockWakeWordEntity,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test that duplicate wake word detections are blocked during the cooldown period."""
client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass)
await client_1.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
await client_2.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
# result
msg = await client_1.receive_json()
assert msg["success"], msg
msg = await client_2.receive_json()
assert msg["success"], msg
# run start
msg = await client_1.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
# wake_word
msg = await client_1.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
# Wake both up at the same time
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
# Get response events
msg = await client_1.receive_json()
event_type_1 = msg["event"]["type"]
msg = await client_2.receive_json()
event_type_2 = msg["event"]["type"]
# One should be a wake up, one should be an error
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}