Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Michael Hansen
fbd6cdf9bb Allow changing timeouts via YAML 2023-10-04 10:16:41 -05:00
8 changed files with 101 additions and 20 deletions

View file

@ -9,7 +9,15 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DATA_CONFIG, DOMAIN from .const import (
CONF_DEBUG_RECORDING_DIR,
CONF_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT,
DATA_CONFIG,
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
)
from .error import PipelineNotFound from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
AudioSettings, AudioSettings,
@ -45,16 +53,29 @@ __all__ = (
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: vol.Schema(
{vol.Optional("debug_recording_dir"): str}, {
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
vol.Optional(
CONF_PIPELINE_TIMEOUT, default=DEFAULT_PIPELINE_TIMEOUT
): vol.Any(float, int),
vol.Optional(
CONF_WAKE_WORD_TIMEOUT, default=DEFAULT_WAKE_WORD_TIMEOUT
): vol.Any(float, int),
},
) )
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
DEFAULT_CONFIG = {
CONF_PIPELINE_TIMEOUT: DEFAULT_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT: DEFAULT_WAKE_WORD_TIMEOUT,
}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Assist pipeline integration.""" """Set up the Assist pipeline integration."""
hass.data[DATA_CONFIG] = config.get(DOMAIN, {}) hass.data[DATA_CONFIG] = config.get(DOMAIN, DEFAULT_CONFIG)
await async_setup_pipeline_store(hass) await async_setup_pipeline_store(hass)
async_register_websocket_api(hass) async_register_websocket_api(hass)

View file

@ -2,3 +2,11 @@
DOMAIN = "assist_pipeline" DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config" DATA_CONFIG = f"{DOMAIN}.config"
CONF_PIPELINE_TIMEOUT = "pipeline_timeout"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
CONF_WAKE_WORD_TIMEOUT = "wake_word_timeout"
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"

View file

@ -48,7 +48,7 @@ from homeassistant.util import (
) )
from homeassistant.util.limited_size_dict import LimitedSizeDict from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DATA_CONFIG, DOMAIN from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
from .error import ( from .error import (
IntentRecognitionError, IntentRecognitionError,
PipelineError, PipelineError,
@ -603,6 +603,8 @@ class PipelineRun:
) )
) )
wake_word_settings = self.wake_word_settings or WakeWordSettings()
# Remove language since it doesn't apply to wake words yet # Remove language since it doesn't apply to wake words yet
metadata_dict.pop("language", None) metadata_dict.pop("language", None)
@ -612,6 +614,7 @@ class PipelineRun:
{ {
"entity_id": self.wake_word_entity_id, "entity_id": self.wake_word_entity_id,
"metadata": metadata_dict, "metadata": metadata_dict,
"timeout": wake_word_settings.timeout or 0,
}, },
) )
) )
@ -619,8 +622,6 @@ class PipelineRun:
if self.debug_recording_queue is not None: if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}") self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
wake_word_settings = self.wake_word_settings or WakeWordSettings()
wake_word_vad: VoiceActivityTimeout | None = None wake_word_vad: VoiceActivityTimeout | None = None
if (wake_word_settings.timeout is not None) and ( if (wake_word_settings.timeout is not None) and (
wake_word_settings.timeout > 0 wake_word_settings.timeout > 0
@ -1032,7 +1033,7 @@ class PipelineRun:
# Directory to save audio for each pipeline run. # Directory to save audio for each pipeline run.
# Configured in YAML for assist_pipeline. # Configured in YAML for assist_pipeline.
if debug_recording_dir := self.hass.data[DATA_CONFIG].get( if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
"debug_recording_dir" CONF_DEBUG_RECORDING_DIR
): ):
if device_id is None: if device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id> # <debug_recording_dir>/<pipeline.name>/<run.id>

View file

@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util from homeassistant.util import language as language_util
from .const import DOMAIN from .const import CONF_PIPELINE_TIMEOUT, CONF_WAKE_WORD_TIMEOUT, DATA_CONFIG, DOMAIN
from .error import PipelineNotFound from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
AudioSettings, AudioSettings,
@ -30,9 +30,6 @@ from .pipeline import (
async_get_pipeline, async_get_pipeline,
) )
DEFAULT_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -117,7 +114,8 @@ async def websocket_run(
) )
return return
timeout = msg.get("timeout", DEFAULT_TIMEOUT) config = hass.data[DATA_CONFIG]
timeout = msg.get("timeout", config[CONF_PIPELINE_TIMEOUT])
start_stage = PipelineStage(msg["start_stage"]) start_stage = PipelineStage(msg["start_stage"])
end_stage = PipelineStage(msg["end_stage"]) end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None handler_id: int | None = None
@ -139,7 +137,7 @@ async def websocket_run(
if start_stage == PipelineStage.WAKE_WORD: if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings( wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), timeout=msg["input"].get("timeout", config[CONF_WAKE_WORD_TIMEOUT]),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0), audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
) )

View file

@ -285,6 +285,7 @@
'format': <AudioFormats.WAV: 'wav'>, 'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>, 'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}), }),
'timeout': 0,
}), }),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>, 'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}), }),
@ -396,6 +397,7 @@
'format': <AudioFormats.WAV: 'wav'>, 'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>, 'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}), }),
'timeout': 0,
}), }),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>, 'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}), }),

View file

@ -373,6 +373,7 @@
'format': 'wav', 'format': 'wav',
'sample_rate': 16000, 'sample_rate': 16000,
}), }),
'timeout': 0,
}) })
# --- # ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2 # name: test_audio_pipeline_with_wake_word_no_timeout.2
@ -474,6 +475,7 @@
'format': 'wav', 'format': 'wav',
'sample_rate': 16000, 'sample_rate': 16000,
}), }),
'timeout': 1,
}) })
# --- # ---
# name: test_audio_pipeline_with_wake_word_timeout.2 # name: test_audio_pipeline_with_wake_word_timeout.2

View file

@ -10,6 +10,10 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt from homeassistant.components import assist_pipeline, stt
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
pipeline = assist_pipeline.async_get_pipeline(hass) pipeline = assist_pipeline.async_get_pipeline(hass)
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
def event_callback(event: assist_pipeline.PipelineEvent): def event_callback(event: assist_pipeline.PipelineEvent):
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
def event_callback(event: assist_pipeline.PipelineEvent): def event_callback(event: assist_pipeline.PipelineEvent):

View file

@ -4,10 +4,15 @@ from unittest.mock import ANY, patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.const import (
CONF_PIPELINE_TIMEOUT,
CONF_WAKE_WORD_TIMEOUT,
DOMAIN,
)
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -1805,3 +1810,43 @@ async def test_audio_pipeline_with_enhancements(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_config_timeouts(
hass: HomeAssistant,
init_supporting_components,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test changing timeouts via YAML config."""
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_PIPELINE_TIMEOUT: 10, CONF_WAKE_WORD_TIMEOUT: 1}},
)
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {"sample_rate": 16000},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"]["runner_data"]["timeout"] == 10
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"]["timeout"] == 1