Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Michael Hansen
fbd6cdf9bb Allow changing timeouts via YAML 2023-10-04 10:16:41 -05:00
8 changed files with 101 additions and 20 deletions

View file

@ -9,7 +9,15 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import DATA_CONFIG, DOMAIN
from .const import (
CONF_DEBUG_RECORDING_DIR,
CONF_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT,
DATA_CONFIG,
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
)
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
@ -45,16 +53,29 @@ __all__ = (
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{vol.Optional("debug_recording_dir"): str},
{
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
vol.Optional(
CONF_PIPELINE_TIMEOUT, default=DEFAULT_PIPELINE_TIMEOUT
): vol.Any(float, int),
vol.Optional(
CONF_WAKE_WORD_TIMEOUT, default=DEFAULT_WAKE_WORD_TIMEOUT
): vol.Any(float, int),
},
)
},
extra=vol.ALLOW_EXTRA,
)
DEFAULT_CONFIG = {
CONF_PIPELINE_TIMEOUT: DEFAULT_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT: DEFAULT_WAKE_WORD_TIMEOUT,
}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Assist pipeline integration."""
hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
hass.data[DATA_CONFIG] = config.get(DOMAIN, DEFAULT_CONFIG)
await async_setup_pipeline_store(hass)
async_register_websocket_api(hass)

View file

@ -2,3 +2,11 @@
DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config"
CONF_PIPELINE_TIMEOUT = "pipeline_timeout"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
CONF_WAKE_WORD_TIMEOUT = "wake_word_timeout"
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"

View file

@ -48,7 +48,7 @@ from homeassistant.util import (
)
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DATA_CONFIG, DOMAIN
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
from .error import (
IntentRecognitionError,
PipelineError,
@ -603,6 +603,8 @@ class PipelineRun:
)
)
wake_word_settings = self.wake_word_settings or WakeWordSettings()
# Remove language since it doesn't apply to wake words yet
metadata_dict.pop("language", None)
@ -612,6 +614,7 @@ class PipelineRun:
{
"entity_id": self.wake_word_entity_id,
"metadata": metadata_dict,
"timeout": wake_word_settings.timeout or 0,
},
)
)
@ -619,8 +622,6 @@ class PipelineRun:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
wake_word_settings = self.wake_word_settings or WakeWordSettings()
wake_word_vad: VoiceActivityTimeout | None = None
if (wake_word_settings.timeout is not None) and (
wake_word_settings.timeout > 0
@ -1032,7 +1033,7 @@ class PipelineRun:
# Directory to save audio for each pipeline run.
# Configured in YAML for assist_pipeline.
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
"debug_recording_dir"
CONF_DEBUG_RECORDING_DIR
):
if device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id>

View file

@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DOMAIN
from .const import CONF_PIPELINE_TIMEOUT, CONF_WAKE_WORD_TIMEOUT, DATA_CONFIG, DOMAIN
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
@ -30,9 +30,6 @@ from .pipeline import (
async_get_pipeline,
)
DEFAULT_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
_LOGGER = logging.getLogger(__name__)
@ -117,7 +114,8 @@ async def websocket_run(
)
return
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
config = hass.data[DATA_CONFIG]
timeout = msg.get("timeout", config[CONF_PIPELINE_TIMEOUT])
start_stage = PipelineStage(msg["start_stage"])
end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None
@ -139,7 +137,7 @@ async def websocket_run(
if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
timeout=msg["input"].get("timeout", config[CONF_WAKE_WORD_TIMEOUT]),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
)

View file

@ -285,6 +285,7 @@
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
@ -396,6 +397,7 @@
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),

View file

@ -373,6 +373,7 @@
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 0,
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2
@ -474,6 +475,7 @@
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 1,
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.2

View file

@ -10,6 +10,10 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
pipeline = assist_pipeline.async_get_pipeline(hass)
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
"assist_pipeline",
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):

View file

@ -4,10 +4,15 @@ from unittest.mock import ANY, patch
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.const import (
CONF_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT,
DOMAIN,
)
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from tests.typing import WebSocketGenerator
@ -1805,3 +1810,43 @@ async def test_audio_pipeline_with_enhancements(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_config_timeouts(
hass: HomeAssistant,
init_supporting_components,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test changing timeouts via YAML config."""
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_PIPELINE_TIMEOUT: 10, CONF_WAKE_WORD_TIMEOUT: 1}},
)
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {"sample_rate": 16000},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"]["runner_data"]["timeout"] == 10
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"]["timeout"] == 1