Compare commits
1 commit
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
fbd6cdf9bb |
8 changed files with 101 additions and 20 deletions
|
@ -9,7 +9,15 @@ from homeassistant.components import stt
|
|||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
CONF_PIPELINE_TIMEOUT,
|
||||
CONF_WAKE_WORD_TIMEOUT,
|
||||
DATA_CONFIG,
|
||||
DEFAULT_PIPELINE_TIMEOUT,
|
||||
DEFAULT_WAKE_WORD_TIMEOUT,
|
||||
DOMAIN,
|
||||
)
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
|
@ -45,16 +53,29 @@ __all__ = (
|
|||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{vol.Optional("debug_recording_dir"): str},
|
||||
{
|
||||
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
|
||||
vol.Optional(
|
||||
CONF_PIPELINE_TIMEOUT, default=DEFAULT_PIPELINE_TIMEOUT
|
||||
): vol.Any(float, int),
|
||||
vol.Optional(
|
||||
CONF_WAKE_WORD_TIMEOUT, default=DEFAULT_WAKE_WORD_TIMEOUT
|
||||
): vol.Any(float, int),
|
||||
},
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
CONF_PIPELINE_TIMEOUT: DEFAULT_PIPELINE_TIMEOUT,
|
||||
CONF_WAKE_WORD_TIMEOUT: DEFAULT_WAKE_WORD_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Assist pipeline integration."""
|
||||
hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
|
||||
hass.data[DATA_CONFIG] = config.get(DOMAIN, DEFAULT_CONFIG)
|
||||
|
||||
await async_setup_pipeline_store(hass)
|
||||
async_register_websocket_api(hass)
|
||||
|
|
|
@ -2,3 +2,11 @@
|
|||
DOMAIN = "assist_pipeline"
|
||||
|
||||
DATA_CONFIG = f"{DOMAIN}.config"
|
||||
|
||||
CONF_PIPELINE_TIMEOUT = "pipeline_timeout"
|
||||
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
|
||||
|
||||
CONF_WAKE_WORD_TIMEOUT = "wake_word_timeout"
|
||||
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
||||
|
||||
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
||||
|
|
|
@ -48,7 +48,7 @@ from homeassistant.util import (
|
|||
)
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
|
||||
from .error import (
|
||||
IntentRecognitionError,
|
||||
PipelineError,
|
||||
|
@ -603,6 +603,8 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
# Remove language since it doesn't apply to wake words yet
|
||||
metadata_dict.pop("language", None)
|
||||
|
||||
|
@ -612,6 +614,7 @@ class PipelineRun:
|
|||
{
|
||||
"entity_id": self.wake_word_entity_id,
|
||||
"metadata": metadata_dict,
|
||||
"timeout": wake_word_settings.timeout or 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
@ -619,8 +622,6 @@ class PipelineRun:
|
|||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
wake_word_vad: VoiceActivityTimeout | None = None
|
||||
if (wake_word_settings.timeout is not None) and (
|
||||
wake_word_settings.timeout > 0
|
||||
|
@ -1032,7 +1033,7 @@ class PipelineRun:
|
|||
# Directory to save audio for each pipeline run.
|
||||
# Configured in YAML for assist_pipeline.
|
||||
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
||||
"debug_recording_dir"
|
||||
CONF_DEBUG_RECORDING_DIR
|
||||
):
|
||||
if device_id is None:
|
||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||
|
|
|
@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import CONF_PIPELINE_TIMEOUT, CONF_WAKE_WORD_TIMEOUT, DATA_CONFIG, DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
|
@ -30,9 +30,6 @@ from .pipeline import (
|
|||
async_get_pipeline,
|
||||
)
|
||||
|
||||
DEFAULT_TIMEOUT = 60 * 5 # seconds
|
||||
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -117,7 +114,8 @@ async def websocket_run(
|
|||
)
|
||||
return
|
||||
|
||||
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||
config = hass.data[DATA_CONFIG]
|
||||
timeout = msg.get("timeout", config[CONF_PIPELINE_TIMEOUT])
|
||||
start_stage = PipelineStage(msg["start_stage"])
|
||||
end_stage = PipelineStage(msg["end_stage"])
|
||||
handler_id: int | None = None
|
||||
|
@ -139,7 +137,7 @@ async def websocket_run(
|
|||
|
||||
if start_stage == PipelineStage.WAKE_WORD:
|
||||
wake_word_settings = WakeWordSettings(
|
||||
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
||||
timeout=msg["input"].get("timeout", config[CONF_WAKE_WORD_TIMEOUT]),
|
||||
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
|
||||
)
|
||||
|
||||
|
|
|
@ -285,6 +285,7 @@
|
|||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
|
@ -396,6 +397,7 @@
|
|||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
|
|
|
@ -373,6 +373,7 @@
|
|||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 0,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.2
|
||||
|
@ -474,6 +475,7 @@
|
|||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
'timeout': 1,
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.2
|
||||
|
|
|
@ -10,6 +10,10 @@ import pytest
|
|||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass)
|
||||
|
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
|
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
|
|||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
|
|
|
@ -4,10 +4,15 @@ from unittest.mock import ANY, patch
|
|||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
CONF_PIPELINE_TIMEOUT,
|
||||
CONF_WAKE_WORD_TIMEOUT,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
@ -1805,3 +1810,43 @@ async def test_audio_pipeline_with_enhancements(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_config_timeouts(
|
||||
hass: HomeAssistant,
|
||||
init_supporting_components,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test changing timeouts via YAML config."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_PIPELINE_TIMEOUT: 10, CONF_WAKE_WORD_TIMEOUT: 1}},
|
||||
)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {"sample_rate": 16000},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"]["runner_data"]["timeout"] == 10
|
||||
|
||||
# wake_word
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-start"
|
||||
assert msg["event"]["data"]["timeout"] == 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue