diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 9a61346f673..29bfe0813b6 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -9,7 +9,15 @@ from homeassistant.components import stt from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType -from .const import DATA_CONFIG, DOMAIN +from .const import ( + CONF_DEBUG_RECORDING_DIR, + CONF_PIPELINE_TIMEOUT, + CONF_WAKE_WORD_TIMEOUT, + DATA_CONFIG, + DEFAULT_PIPELINE_TIMEOUT, + DEFAULT_WAKE_WORD_TIMEOUT, + DOMAIN, +) from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -45,16 +53,29 @@ __all__ = ( CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( - {vol.Optional("debug_recording_dir"): str}, + { + vol.Optional(CONF_DEBUG_RECORDING_DIR): str, + vol.Optional( + CONF_PIPELINE_TIMEOUT, default=DEFAULT_PIPELINE_TIMEOUT + ): vol.Any(float, int), + vol.Optional( + CONF_WAKE_WORD_TIMEOUT, default=DEFAULT_WAKE_WORD_TIMEOUT + ): vol.Any(float, int), + }, ) }, extra=vol.ALLOW_EXTRA, ) +DEFAULT_CONFIG = { + CONF_PIPELINE_TIMEOUT: DEFAULT_PIPELINE_TIMEOUT, + CONF_WAKE_WORD_TIMEOUT: DEFAULT_WAKE_WORD_TIMEOUT, +} + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Assist pipeline integration.""" - hass.data[DATA_CONFIG] = config.get(DOMAIN, {}) + hass.data[DATA_CONFIG] = config.get(DOMAIN, DEFAULT_CONFIG) await async_setup_pipeline_store(hass) async_register_websocket_api(hass) diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index e21d9003a69..1fb70cd3ee3 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -2,3 +2,11 @@ DOMAIN = "assist_pipeline" DATA_CONFIG = f"{DOMAIN}.config" + +CONF_PIPELINE_TIMEOUT = "pipeline_timeout" +DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds + +CONF_WAKE_WORD_TIMEOUT = "wake_word_timeout" +DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds + +CONF_DEBUG_RECORDING_DIR = "debug_recording_dir" diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 76444fb2436..ee8d3016743 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -48,7 +48,7 @@ from homeassistant.util import ( ) from homeassistant.util.limited_size_dict import LimitedSizeDict -from .const import DATA_CONFIG, DOMAIN +from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN from .error import ( IntentRecognitionError, PipelineError, @@ -603,6 +603,8 @@ class PipelineRun: ) ) + wake_word_settings = self.wake_word_settings or WakeWordSettings() + # Remove language since it doesn't apply to wake words yet metadata_dict.pop("language", None) @@ -612,6 +614,7 @@ class PipelineRun: { "entity_id": self.wake_word_entity_id, "metadata": metadata_dict, + "timeout": wake_word_settings.timeout or 0, }, ) ) @@ -619,8 +622,6 @@ class PipelineRun: if self.debug_recording_queue is not None: self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}") - wake_word_settings = self.wake_word_settings or WakeWordSettings() - wake_word_vad: VoiceActivityTimeout | None = None if (wake_word_settings.timeout is not None) and ( wake_word_settings.timeout > 0 @@ -1032,7 +1033,7 @@ class PipelineRun: # Directory to save audio for each pipeline run. # Configured in YAML for assist_pipeline. if debug_recording_dir := self.hass.data[DATA_CONFIG].get( - "debug_recording_dir" + CONF_DEBUG_RECORDING_DIR ): if device_id is None: # // diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 798843ea6e3..7634526f56a 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.util import language as language_util -from .const import DOMAIN +from .const import CONF_PIPELINE_TIMEOUT, CONF_WAKE_WORD_TIMEOUT, DATA_CONFIG, DOMAIN from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -30,9 +30,6 @@ from .pipeline import ( async_get_pipeline, ) -DEFAULT_TIMEOUT = 60 * 5 # seconds -DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds - _LOGGER = logging.getLogger(__name__) @@ -117,7 +114,8 @@ async def websocket_run( ) return - timeout = msg.get("timeout", DEFAULT_TIMEOUT) + config = hass.data[DATA_CONFIG] + timeout = msg.get("timeout", config[CONF_PIPELINE_TIMEOUT]) start_stage = PipelineStage(msg["start_stage"]) end_stage = PipelineStage(msg["end_stage"]) handler_id: int | None = None @@ -139,7 +137,7 @@ async def websocket_run( if start_stage == PipelineStage.WAKE_WORD: wake_word_settings = WakeWordSettings( - timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), + timeout=msg["input"].get("timeout", config[CONF_WAKE_WORD_TIMEOUT]), audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0), ) diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 3f0582f2bfb..e822759d208 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -285,6 +285,7 @@ 'format': , 'sample_rate': , }), + 'timeout': 0, }), 'type': , }), @@ -396,6 +397,7 @@ 'format': , 'sample_rate': , }), + 'timeout': 0, }), 'type': , }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 7cecf9fed40..2bb87ddf7d6 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -373,6 +373,7 @@ 'format': 'wav', 'sample_rate': 16000, }), + 'timeout': 0, }) # --- # name: test_audio_pipeline_with_wake_word_no_timeout.2 @@ -474,6 +475,7 @@ 'format': 'wav', 'sample_rate': 16000, }), + 'timeout': 1, }) # --- # name: test_audio_pipeline_with_wake_word_timeout.2 diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 98ecae628f1..a98858a1bce 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -10,6 +10,10 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, stt +from homeassistant.components.assist_pipeline.const import ( + CONF_DEBUG_RECORDING_DIR, + DOMAIN, +) from homeassistant.core import Context, HomeAssistant from homeassistant.setup import async_setup_component @@ -395,8 +399,8 @@ async def test_pipeline_save_audio( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) pipeline = assist_pipeline.async_get_pipeline(hass) @@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): @@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error( temp_dir = Path(temp_dir_str) assert await async_setup_component( hass, - "assist_pipeline", - {"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, + DOMAIN, + {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}}, ) def event_callback(event: assist_pipeline.PipelineEvent): diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index f995a0d3577..d862b7df975 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -4,10 +4,15 @@ from unittest.mock import ANY, patch from syrupy.assertion import SnapshotAssertion -from homeassistant.components.assist_pipeline.const import DOMAIN +from homeassistant.components.assist_pipeline.const import ( + CONF_PIPELINE_TIMEOUT, + CONF_WAKE_WORD_TIMEOUT, + DOMAIN, +) from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.setup import async_setup_component from tests.typing import WebSocketGenerator @@ -1805,3 +1810,43 @@ async def test_audio_pipeline_with_enhancements( msg = await client.receive_json() assert msg["success"] assert msg["result"] == {"events": events} + + +async def test_config_timeouts( + hass: HomeAssistant, + init_supporting_components, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test changing timeouts via YAML config.""" + assert await async_setup_component( + hass, + DOMAIN, + {DOMAIN: {CONF_PIPELINE_TIMEOUT: 10, CONF_WAKE_WORD_TIMEOUT: 1}}, + ) + + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": {"sample_rate": 16000}, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"], msg + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"]["runner_data"]["timeout"] == 10 + + # wake_word + msg = await client.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"]["timeout"] == 1