Add wake word cooldown to avoid duplicate wake-ups (#101417)
This commit is contained in:
parent
48a23798d0
commit
244f6d8002
8 changed files with 198 additions and 18 deletions
|
@ -9,7 +9,7 @@ from homeassistant.components import stt
|
|||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
|
@ -45,7 +45,9 @@ __all__ = (
|
|||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{vol.Optional("debug_recording_dir"): str},
|
||||
{
|
||||
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
|
||||
},
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
|
|
|
@ -2,3 +2,12 @@
|
|||
DOMAIN = "assist_pipeline"
|
||||
|
||||
DATA_CONFIG = f"{DOMAIN}.config"
|
||||
|
||||
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
|
||||
|
||||
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
||||
|
||||
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
||||
|
||||
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
||||
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
||||
|
|
|
@ -48,7 +48,13 @@ from homeassistant.util import (
|
|||
)
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DATA_CONFIG,
|
||||
DATA_LAST_WAKE_UP,
|
||||
DEFAULT_WAKE_WORD_COOLDOWN,
|
||||
DOMAIN,
|
||||
)
|
||||
from .error import (
|
||||
IntentRecognitionError,
|
||||
PipelineError,
|
||||
|
@ -399,6 +405,9 @@ class WakeWordSettings:
|
|||
audio_seconds_to_buffer: float = 0
|
||||
"""Seconds of audio to buffer before detection and forward to STT."""
|
||||
|
||||
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
|
||||
"""Seconds after a wake word detection where other detections are ignored."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioSettings:
|
||||
|
@ -603,6 +612,8 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
# Remove language since it doesn't apply to wake words yet
|
||||
metadata_dict.pop("language", None)
|
||||
|
||||
|
@ -612,6 +623,7 @@ class PipelineRun:
|
|||
{
|
||||
"entity_id": self.wake_word_entity_id,
|
||||
"metadata": metadata_dict,
|
||||
"timeout": wake_word_settings.timeout or 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
@ -619,8 +631,6 @@ class PipelineRun:
|
|||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
wake_word_vad: VoiceActivityTimeout | None = None
|
||||
if (wake_word_settings.timeout is not None) and (
|
||||
wake_word_settings.timeout > 0
|
||||
|
@ -670,6 +680,17 @@ class PipelineRun:
|
|||
if result is None:
|
||||
wake_word_output: dict[str, Any] = {}
|
||||
else:
|
||||
# Avoid duplicate detections by checking cooldown
|
||||
last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP)
|
||||
if last_wake_up is not None:
|
||||
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
|
||||
_LOGGER.debug("Duplicate wake word detection occurred")
|
||||
raise WakeWordDetectionAborted
|
||||
|
||||
# Record last wake up time to block duplicate detections
|
||||
self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic()
|
||||
|
||||
if result.queued_audio:
|
||||
# Add audio that was pending at detection.
|
||||
#
|
||||
|
@ -1032,7 +1053,7 @@ class PipelineRun:
|
|||
# Directory to save audio for each pipeline run.
|
||||
# Configured in YAML for assist_pipeline.
|
||||
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
||||
"debug_recording_dir"
|
||||
CONF_DEBUG_RECORDING_DIR
|
||||
):
|
||||
if device_id is None:
|
||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||
|
|
|
@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
|
@ -30,9 +30,6 @@ from .pipeline import (
|
|||
async_get_pipeline,
|
||||
)
|
||||
|
||||
DEFAULT_TIMEOUT = 60 * 5 # seconds
|
||||
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -117,7 +114,7 @@ async def websocket_run(
|
|||
)
|
||||
return
|
||||
|
||||
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||
timeout = msg.get("timeout", DEFAULT_PIPELINE_TIMEOUT)
|
||||
start_stage = PipelineStage(msg["start_stage"])
|
||||
end_stage = PipelineStage(msg["end_stage"])
|
||||
handler_id: int | None = None
|
||||
|
|
|
@ -285,6 +285,7 @@
|
|||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
|
@ -396,6 +397,7 @@
|
|||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
|
|
|
@ -373,6 +373,7 @@
|
|||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 0,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.2
|
||||
|
@ -474,6 +475,7 @@
|
|||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 1,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.2
|
||||
|
@ -655,3 +657,63 @@
|
|||
# name: test_tts_failed.2
|
||||
None
|
||||
# ---
|
||||
# name: test_wake_word_cooldown
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown.1
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown.2
|
||||
dict({
|
||||
'entity_id': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 3,
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown.3
|
||||
dict({
|
||||
'entity_id': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 3,
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown.4
|
||||
dict({
|
||||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown.5
|
||||
dict({
|
||||
'code': 'wake_word_detection_aborted',
|
||||
'message': '',
|
||||
})
|
||||
# ---
|
||||
|
|
|
@ -10,6 +10,10 @@ import pytest
|
|||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass)
|
||||
|
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
|
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
|
|
|
@ -9,6 +9,8 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .conftest import MockWakeWordEntity
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
|
@ -266,7 +268,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
|||
events.append(msg["event"])
|
||||
|
||||
# "audio"
|
||||
await client.send_bytes(bytes([1]) + b"wake word")
|
||||
await client.send_bytes(bytes([handler_id]) + b"wake word")
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-end"
|
||||
|
@ -1805,3 +1807,84 @@ async def test_audio_pipeline_with_enhancements(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_wake_word_cooldown(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that duplicate wake word detections are blocked during the cooldown period."""
|
||||
client_1 = await hass_ws_client(hass)
|
||||
client_2 = await hass_ws_client(hass)
|
||||
|
||||
await client_1.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"no_vad": True,
|
||||
"no_chunking": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await client_2.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"no_vad": True,
|
||||
"no_chunking": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# wake_word
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# Wake both up at the same time
|
||||
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||
|
||||
# Get response events
|
||||
msg = await client_1.receive_json()
|
||||
event_type_1 = msg["event"]["type"]
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
event_type_2 = msg["event"]["type"]
|
||||
|
||||
# One should be a wake up, one should be an error
|
||||
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||
|
|
Loading…
Add table
Reference in a new issue