Add WS API for debugging previous assist_pipeline runs (#91541)
* Add WS API for debugging previous assist_pipeline runs * Improve typing
This commit is contained in:
parent
b597415b01
commit
0ecd23baee
7 changed files with 564 additions and 32 deletions
|
@ -24,6 +24,7 @@ from homeassistant.helpers.collection import (
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||||
|
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .error import (
|
from .error import (
|
||||||
|
@ -46,6 +47,8 @@ STORAGE_FIELDS = {
|
||||||
vol.Required("tts_engine"): str,
|
vol.Required("tts_engine"): str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
STORED_PIPELINE_RUNS = 10
|
||||||
|
|
||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,14 +56,14 @@ async def async_get_pipeline(
|
||||||
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
||||||
) -> Pipeline | None:
|
) -> Pipeline | None:
|
||||||
"""Get a pipeline by id or create one for a language."""
|
"""Get a pipeline by id or create one for a language."""
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
|
||||||
if pipeline_id is not None:
|
if pipeline_id is not None:
|
||||||
return pipeline_store.data.get(pipeline_id)
|
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
# Construct a pipeline for the required/configured language
|
# Construct a pipeline for the required/configured language
|
||||||
language = language or hass.config.language
|
language = language or hass.config.language
|
||||||
return await pipeline_store.async_create_item(
|
return await pipeline_data.pipeline_store.async_create_item(
|
||||||
{
|
{
|
||||||
"name": language,
|
"name": language,
|
||||||
"language": language,
|
"language": language,
|
||||||
|
@ -171,6 +174,8 @@ class PipelineRun:
|
||||||
tts_engine: str | None = None
|
tts_engine: str | None = None
|
||||||
tts_options: dict | None = None
|
tts_options: dict | None = None
|
||||||
|
|
||||||
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
@ -181,6 +186,23 @@ class PipelineRun:
|
||||||
):
|
):
|
||||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||||
|
if self.pipeline.id not in pipeline_data.pipeline_runs:
|
||||||
|
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
|
||||||
|
size_limit=STORED_PIPELINE_RUNS
|
||||||
|
)
|
||||||
|
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def process_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Log an event and call listener."""
|
||||||
|
self.event_callback(event)
|
||||||
|
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||||
|
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
|
||||||
|
# This run has been evicted from the logged pipeline runs already
|
||||||
|
return
|
||||||
|
pipeline_data.pipeline_runs[self.pipeline.id][self.id].append(event)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Emit run start event."""
|
"""Emit run start event."""
|
||||||
data = {
|
data = {
|
||||||
|
@ -190,11 +212,11 @@ class PipelineRun:
|
||||||
if self.runner_data is not None:
|
if self.runner_data is not None:
|
||||||
data["runner_data"] = self.runner_data
|
data["runner_data"] = self.runner_data
|
||||||
|
|
||||||
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||||
|
|
||||||
def end(self) -> None:
|
def end(self) -> None:
|
||||||
"""Emit run end event."""
|
"""Emit run end event."""
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.RUN_END,
|
PipelineEventType.RUN_END,
|
||||||
)
|
)
|
||||||
|
@ -233,7 +255,7 @@ class PipelineRun:
|
||||||
|
|
||||||
engine = self.stt_provider.name
|
engine = self.stt_provider.name
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.STT_START,
|
PipelineEventType.STT_START,
|
||||||
{
|
{
|
||||||
|
@ -268,7 +290,7 @@ class PipelineRun:
|
||||||
code="stt-no-text-recognized", message="No text recognized"
|
code="stt-no-text-recognized", message="No text recognized"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.STT_END,
|
PipelineEventType.STT_END,
|
||||||
{
|
{
|
||||||
|
@ -306,7 +328,7 @@ class PipelineRun:
|
||||||
if self.intent_agent is None:
|
if self.intent_agent is None:
|
||||||
raise RuntimeError("Recognize intent was not prepared")
|
raise RuntimeError("Recognize intent was not prepared")
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_START,
|
PipelineEventType.INTENT_START,
|
||||||
{
|
{
|
||||||
|
@ -334,7 +356,7 @@ class PipelineRun:
|
||||||
|
|
||||||
_LOGGER.debug("conversation result %s", conversation_result)
|
_LOGGER.debug("conversation result %s", conversation_result)
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_END,
|
PipelineEventType.INTENT_END,
|
||||||
{"intent_output": conversation_result.as_dict()},
|
{"intent_output": conversation_result.as_dict()},
|
||||||
|
@ -379,7 +401,7 @@ class PipelineRun:
|
||||||
if self.tts_engine is None:
|
if self.tts_engine is None:
|
||||||
raise RuntimeError("Text to speech was not prepared")
|
raise RuntimeError("Text to speech was not prepared")
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_START,
|
PipelineEventType.TTS_START,
|
||||||
{
|
{
|
||||||
|
@ -412,7 +434,7 @@ class PipelineRun:
|
||||||
|
|
||||||
_LOGGER.debug("TTS result %s", tts_media)
|
_LOGGER.debug("TTS result %s", tts_media)
|
||||||
|
|
||||||
self.event_callback(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_END,
|
PipelineEventType.TTS_END,
|
||||||
{
|
{
|
||||||
|
@ -480,7 +502,7 @@ class PipelineInput:
|
||||||
await self.run.text_to_speech(tts_input)
|
await self.run.text_to_speech(tts_input)
|
||||||
|
|
||||||
except PipelineError as err:
|
except PipelineError as err:
|
||||||
self.run.event_callback(
|
self.run.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.ERROR,
|
PipelineEventType.ERROR,
|
||||||
{"code": err.code, "message": err.message},
|
{"code": err.code, "message": err.message},
|
||||||
|
@ -691,6 +713,14 @@ class PipelineStorageCollectionWebsocket(
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineData:
|
||||||
|
"""Store and debug data stored in hass.data."""
|
||||||
|
|
||||||
|
pipeline_runs: dict[str, LimitedSizeDict[str, list[PipelineEvent]]]
|
||||||
|
pipeline_store: PipelineStorageCollection
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
||||||
"""Set up the pipeline storage collection."""
|
"""Set up the pipeline storage collection."""
|
||||||
pipeline_store = PipelineStorageCollection(
|
pipeline_store = PipelineStorageCollection(
|
||||||
|
@ -700,4 +730,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
||||||
PipelineStorageCollectionWebsocket(
|
PipelineStorageCollectionWebsocket(
|
||||||
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
||||||
).async_setup(hass)
|
).async_setup(hass)
|
||||||
hass.data[DOMAIN] = pipeline_store
|
hass.data[DOMAIN] = PipelineData({}, pipeline_store)
|
||||||
|
|
|
@ -12,7 +12,9 @@ from homeassistant.components import stt, websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
|
PipelineData,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
|
@ -69,6 +71,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||||
|
websocket_api.async_register_command(hass, websocket_get_run)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
|
@ -193,14 +197,82 @@ async def websocket_run(
|
||||||
async with async_timeout.timeout(timeout):
|
async with async_timeout.timeout(timeout):
|
||||||
await run_task
|
await run_task
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
connection.send_event(
|
pipeline_input.run.process_event(
|
||||||
msg["id"],
|
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.ERROR,
|
PipelineEventType.ERROR,
|
||||||
{"code": "timeout", "message": "Timeout running pipeline"},
|
{"code": "timeout", "message": "Timeout running pipeline"},
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if unregister_handler is not None:
|
if unregister_handler is not None:
|
||||||
# Unregister binary handler
|
# Unregister binary handler
|
||||||
unregister_handler()
|
unregister_handler()
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_pipeline/pipeline_debug/list",
|
||||||
|
vol.Required("pipeline_id"): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def websocket_list_runs(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""List pipeline runs for which debug data is available."""
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = msg["pipeline_id"]
|
||||||
|
|
||||||
|
if pipeline_id not in pipeline_data.pipeline_runs:
|
||||||
|
connection.send_result(msg["id"], {"pipeline_runs": []})
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||||
|
|
||||||
|
connection.send_result(msg["id"], {"pipeline_runs": list(pipeline_runs)})
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
|
||||||
|
vol.Required("pipeline_id"): str,
|
||||||
|
vol.Required("pipeline_run_id"): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def websocket_get_run(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Get debug data for a pipeline run."""
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = msg["pipeline_id"]
|
||||||
|
pipeline_run_id = msg["pipeline_run_id"]
|
||||||
|
|
||||||
|
if pipeline_id not in pipeline_data.pipeline_runs:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.const.ERR_NOT_FOUND,
|
||||||
|
f"pipeline_id {pipeline_id} not found",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||||
|
|
||||||
|
if pipeline_run_id not in pipeline_runs:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.const.ERR_NOT_FOUND,
|
||||||
|
f"pipeline_run_id {pipeline_run_id} not found",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"],
|
||||||
|
{"events": pipeline_runs[pipeline_run_id]},
|
||||||
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@ import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||||
|
|
||||||
from . import websocket_api
|
from . import websocket_api
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -24,7 +25,6 @@ from .const import (
|
||||||
DEFAULT_STORED_TRACES,
|
DEFAULT_STORED_TRACES,
|
||||||
)
|
)
|
||||||
from .models import ActionTrace, BaseTrace, RestoredTrace
|
from .models import ActionTrace, BaseTrace, RestoredTrace
|
||||||
from .utils import LimitedSizeDict
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -72,6 +72,79 @@
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug
|
||||||
|
dict({
|
||||||
|
'language': 'en-US',
|
||||||
|
'pipeline': 'en-US',
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 30,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.2
|
||||||
|
dict({
|
||||||
|
'stt_output': dict({
|
||||||
|
'text': 'test transcript',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.3
|
||||||
|
dict({
|
||||||
|
'engine': 'homeassistant',
|
||||||
|
'intent_input': 'test transcript',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.4
|
||||||
|
dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'code': 'no_intent_match',
|
||||||
|
}),
|
||||||
|
'language': 'en-US',
|
||||||
|
'response_type': 'error',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': "Sorry, I couldn't understand that",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.5
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_debug.6
|
||||||
|
dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_intent_failed
|
# name: test_intent_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en-US',
|
||||||
|
|
|
@ -5,6 +5,7 @@ from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
STORAGE_VERSION,
|
STORAGE_VERSION,
|
||||||
|
PipelineData,
|
||||||
PipelineStorageCollection,
|
PipelineStorageCollection,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
@ -42,7 +43,8 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||||
]
|
]
|
||||||
pipeline_ids = []
|
pipeline_ids = []
|
||||||
|
|
||||||
store1: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store1 = pipeline_data.pipeline_store
|
||||||
for pipeline in pipelines:
|
for pipeline in pipelines:
|
||||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||||
assert len(store1.data) == 3
|
assert len(store1.data) == 3
|
||||||
|
@ -103,6 +105,7 @@ async def test_loading_datasets_from_storage(
|
||||||
|
|
||||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
assert len(store.data) == 3
|
assert len(store.data) == 3
|
||||||
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||||
|
|
|
@ -5,10 +5,7 @@ from unittest.mock import ANY, MagicMock, patch
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||||
Pipeline,
|
|
||||||
PipelineStorageCollection,
|
|
||||||
)
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
@ -21,6 +18,7 @@ async def test_text_only_pipeline(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
|
@ -40,20 +38,39 @@ async def test_text_only_pipeline(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# intent
|
# intent
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-start"
|
assert msg["event"]["type"] == "intent-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-end"
|
assert msg["event"]["type"] == "intent-end"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# run end
|
# run end
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-end"
|
assert msg["event"]["type"] == "run-end"
|
||||||
assert msg["event"]["data"] is None
|
assert msg["event"]["data"] is None
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_audio_pipeline(
|
async def test_audio_pipeline(
|
||||||
|
@ -63,6 +80,7 @@ async def test_audio_pipeline(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with audio input/output."""
|
"""Test events from a pipeline run with audio input/output."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
|
@ -84,11 +102,13 @@ async def test_audio_pipeline(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# stt
|
# stt
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-start"
|
assert msg["event"]["type"] == "stt-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(bytes([1]))
|
await client.send_bytes(bytes([1]))
|
||||||
|
@ -96,29 +116,50 @@ async def test_audio_pipeline(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-end"
|
assert msg["event"]["type"] == "stt-end"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# intent
|
# intent
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-start"
|
assert msg["event"]["type"] == "intent-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-end"
|
assert msg["event"]["type"] == "intent-end"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# text to speech
|
# text to speech
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "tts-start"
|
assert msg["event"]["type"] == "tts-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "tts-end"
|
assert msg["event"]["type"] == "tts-end"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# run end
|
# run end
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-end"
|
assert msg["event"]["type"] == "run-end"
|
||||||
assert msg["event"]["data"] is None
|
assert msg["event"]["data"] is None
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_intent_timeout(
|
async def test_intent_timeout(
|
||||||
|
@ -128,6 +169,7 @@ async def test_intent_timeout(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test partial pipeline run with conversation agent timeout."""
|
"""Test partial pipeline run with conversation agent timeout."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
async def sleepy_converse(*args, **kwargs):
|
async def sleepy_converse(*args, **kwargs):
|
||||||
|
@ -155,16 +197,34 @@ async def test_intent_timeout(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# intent
|
# intent
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-start"
|
assert msg["event"]["type"] == "intent-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# timeout error
|
# timeout error
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_text_pipeline_timeout(
|
async def test_text_pipeline_timeout(
|
||||||
|
@ -174,6 +234,7 @@ async def test_text_pipeline_timeout(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test text-only pipeline run with immediate timeout."""
|
"""Test text-only pipeline run with immediate timeout."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
async def sleepy_run(*args, **kwargs):
|
async def sleepy_run(*args, **kwargs):
|
||||||
|
@ -201,6 +262,22 @@ async def test_text_pipeline_timeout(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_intent_failed(
|
async def test_intent_failed(
|
||||||
|
@ -210,6 +287,7 @@ async def test_intent_failed(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test text-only pipeline run with conversation agent error."""
|
"""Test text-only pipeline run with conversation agent error."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
|
@ -233,16 +311,34 @@ async def test_intent_failed(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# intent start
|
# intent start
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-start"
|
assert msg["event"]["type"] == "intent-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# intent error
|
# intent error
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"]["code"] == "intent-failed"
|
assert msg["event"]["data"]["code"] == "intent-failed"
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_audio_pipeline_timeout(
|
async def test_audio_pipeline_timeout(
|
||||||
|
@ -252,6 +348,7 @@ async def test_audio_pipeline_timeout(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test audio pipeline run with immediate timeout."""
|
"""Test audio pipeline run with immediate timeout."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
async def sleepy_run(*args, **kwargs):
|
async def sleepy_run(*args, **kwargs):
|
||||||
|
@ -281,6 +378,22 @@ async def test_audio_pipeline_timeout(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"]["code"] == "timeout"
|
assert msg["event"]["data"]["code"] == "timeout"
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_stt_provider_missing(
|
async def test_stt_provider_missing(
|
||||||
|
@ -320,12 +433,13 @@ async def test_stt_stream_failed(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
|
events = []
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
|
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
|
||||||
new=MagicMock(side_effect=RuntimeError),
|
new=MagicMock(side_effect=RuntimeError),
|
||||||
):
|
):
|
||||||
client = await hass_ws_client(hass)
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
|
@ -345,11 +459,13 @@ async def test_stt_stream_failed(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# stt
|
# stt
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-start"
|
assert msg["event"]["type"] == "stt-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(b"1")
|
await client.send_bytes(b"1")
|
||||||
|
@ -358,6 +474,22 @@ async def test_stt_stream_failed(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_failed(
|
async def test_tts_failed(
|
||||||
|
@ -367,15 +499,15 @@ async def test_tts_failed(
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test pipeline run with text to speech error."""
|
"""Test pipeline run with text to speech error."""
|
||||||
|
events = []
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
new=MagicMock(return_value=RuntimeError),
|
new=MagicMock(return_value=RuntimeError),
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"id": 5,
|
|
||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "tts",
|
"start_stage": "tts",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
|
@ -391,16 +523,34 @@ async def test_tts_failed(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# tts start
|
# tts start
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "tts-start"
|
assert msg["event"]["type"] == "tts-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
# tts error
|
# tts error
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "error"
|
assert msg["event"]["type"] == "error"
|
||||||
assert msg["event"]["data"]["code"] == "tts-failed"
|
assert msg["event"]["data"]["code"] == "tts-failed"
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_stage_order(
|
async def test_invalid_stage_order(
|
||||||
|
@ -428,7 +578,8 @@ async def test_add_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can add a pipeline."""
|
"""Test we can add a pipeline."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
|
@ -468,7 +619,8 @@ async def test_delete_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can delete a pipeline."""
|
"""Test we can delete a pipeline."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
|
@ -542,7 +694,8 @@ async def test_list_pipelines(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can list pipelines."""
|
"""Test we can list pipelines."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
@ -586,7 +739,8 @@ async def test_update_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can list pipelines."""
|
"""Test we can list pipelines."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
|
@ -660,7 +814,8 @@ async def test_set_preferred_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test updating the preferred pipeline."""
|
"""Test updating the preferred pipeline."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
|
@ -715,3 +870,202 @@ async def test_set_preferred_pipeline_wrong_id(
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["error"]["code"] == "not_found"
|
assert msg["error"]["code"] == "not_found"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_audio_pipeline_debug(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test debug listing events from a pipeline run with audio input/output."""
|
||||||
|
events = []
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 44100,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# End of audio stream (handler id + empty payload)
|
||||||
|
await client.send_bytes(bytes([1]))
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# intent
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# text to speech
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] is None
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# Get the id of the pipeline
|
||||||
|
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(msg["result"]["pipelines"]) == 1
|
||||||
|
|
||||||
|
pipeline_id = msg["result"]["pipelines"][0]["id"]
|
||||||
|
|
||||||
|
# Get the id for the run
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"pipeline_runs": [ANY]}
|
||||||
|
|
||||||
|
pipeline_run_id = msg["result"]["pipeline_runs"][0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_debug_list_runs_wrong_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test debug listing events from a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": "blah"}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"pipeline_runs": []}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_debug_get_run_wrong_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test debug listing events from a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": "blah",
|
||||||
|
"pipeline_run_id": "blah",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "pipeline_id blah not found",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_debug_get_run_wrong_pipeline_run(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test debug listing events from a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# consume events
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-end"
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
|
||||||
|
# Get the id of the pipeline
|
||||||
|
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(msg["result"]["pipelines"]) == 1
|
||||||
|
pipeline_id = msg["result"]["pipelines"][0]["id"]
|
||||||
|
|
||||||
|
# get debug data for the wrong run
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": "blah",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "pipeline_run_id blah not found",
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue