Add WS API for debugging previous assist_pipeline runs (#91541)
* Add WS API for debugging previous assist_pipeline runs * Improve typing
This commit is contained in:
parent
b597415b01
commit
0ecd23baee
7 changed files with 564 additions and 32 deletions
|
@ -24,6 +24,7 @@ from homeassistant.helpers.collection import (
|
|||
)
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .const import DOMAIN
|
||||
from .error import (
|
||||
|
@ -46,6 +47,8 @@ STORAGE_FIELDS = {
|
|||
vol.Required("tts_engine"): str,
|
||||
}
|
||||
|
||||
STORED_PIPELINE_RUNS = 10
|
||||
|
||||
SAVE_DELAY = 10
|
||||
|
||||
|
||||
|
@ -53,14 +56,14 @@ async def async_get_pipeline(
|
|||
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
||||
) -> Pipeline | None:
|
||||
"""Get a pipeline by id or create one for a language."""
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
||||
if pipeline_id is not None:
|
||||
return pipeline_store.data.get(pipeline_id)
|
||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||
|
||||
# Construct a pipeline for the required/configured language
|
||||
language = language or hass.config.language
|
||||
return await pipeline_store.async_create_item(
|
||||
return await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"name": language,
|
||||
"language": language,
|
||||
|
@ -171,6 +174,8 @@ class PipelineRun:
|
|||
tts_engine: str | None = None
|
||||
tts_options: dict | None = None
|
||||
|
||||
id: str = field(default_factory=ulid_util.ulid)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
|
@ -181,6 +186,23 @@ class PipelineRun:
|
|||
):
|
||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
if self.pipeline.id not in pipeline_data.pipeline_runs:
|
||||
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
|
||||
size_limit=STORED_PIPELINE_RUNS
|
||||
)
|
||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = []
|
||||
|
||||
@callback
|
||||
def process_event(self, event: PipelineEvent) -> None:
|
||||
"""Log an event and call listener."""
|
||||
self.event_callback(event)
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
|
||||
# This run has been evicted from the logged pipeline runs already
|
||||
return
|
||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id].append(event)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Emit run start event."""
|
||||
data = {
|
||||
|
@ -190,11 +212,11 @@ class PipelineRun:
|
|||
if self.runner_data is not None:
|
||||
data["runner_data"] = self.runner_data
|
||||
|
||||
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
|
||||
def end(self) -> None:
|
||||
"""Emit run end event."""
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.RUN_END,
|
||||
)
|
||||
|
@ -233,7 +255,7 @@ class PipelineRun:
|
|||
|
||||
engine = self.stt_provider.name
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.STT_START,
|
||||
{
|
||||
|
@ -268,7 +290,7 @@ class PipelineRun:
|
|||
code="stt-no-text-recognized", message="No text recognized"
|
||||
)
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.STT_END,
|
||||
{
|
||||
|
@ -306,7 +328,7 @@ class PipelineRun:
|
|||
if self.intent_agent is None:
|
||||
raise RuntimeError("Recognize intent was not prepared")
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_START,
|
||||
{
|
||||
|
@ -334,7 +356,7 @@ class PipelineRun:
|
|||
|
||||
_LOGGER.debug("conversation result %s", conversation_result)
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_END,
|
||||
{"intent_output": conversation_result.as_dict()},
|
||||
|
@ -379,7 +401,7 @@ class PipelineRun:
|
|||
if self.tts_engine is None:
|
||||
raise RuntimeError("Text to speech was not prepared")
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_START,
|
||||
{
|
||||
|
@ -412,7 +434,7 @@ class PipelineRun:
|
|||
|
||||
_LOGGER.debug("TTS result %s", tts_media)
|
||||
|
||||
self.event_callback(
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_END,
|
||||
{
|
||||
|
@ -480,7 +502,7 @@ class PipelineInput:
|
|||
await self.run.text_to_speech(tts_input)
|
||||
|
||||
except PipelineError as err:
|
||||
self.run.event_callback(
|
||||
self.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": err.code, "message": err.message},
|
||||
|
@ -691,6 +713,14 @@ class PipelineStorageCollectionWebsocket(
|
|||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineData:
|
||||
"""Store and debug data stored in hass.data."""
|
||||
|
||||
pipeline_runs: dict[str, LimitedSizeDict[str, list[PipelineEvent]]]
|
||||
pipeline_store: PipelineStorageCollection
|
||||
|
||||
|
||||
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
||||
"""Set up the pipeline storage collection."""
|
||||
pipeline_store = PipelineStorageCollection(
|
||||
|
@ -700,4 +730,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
|||
PipelineStorageCollectionWebsocket(
|
||||
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
hass.data[DOMAIN] = pipeline_store
|
||||
hass.data[DOMAIN] = PipelineData({}, pipeline_store)
|
||||
|
|
|
@ -12,7 +12,9 @@ from homeassistant.components import stt, websocket_api
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import (
|
||||
PipelineData,
|
||||
PipelineError,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
|
@ -69,6 +71,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
),
|
||||
),
|
||||
)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
|
@ -193,14 +197,82 @@ async def websocket_run(
|
|||
async with async_timeout.timeout(timeout):
|
||||
await run_task
|
||||
except asyncio.TimeoutError:
|
||||
connection.send_event(
|
||||
msg["id"],
|
||||
pipeline_input.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": "timeout", "message": "Timeout running pipeline"},
|
||||
),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if unregister_handler is not None:
|
||||
# Unregister binary handler
|
||||
unregister_handler()
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/pipeline_debug/list",
|
||||
vol.Required("pipeline_id"): str,
|
||||
}
|
||||
)
|
||||
def websocket_list_runs(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List pipeline runs for which debug data is available."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
|
||||
if pipeline_id not in pipeline_data.pipeline_runs:
|
||||
connection.send_result(msg["id"], {"pipeline_runs": []})
|
||||
return
|
||||
|
||||
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||
|
||||
connection.send_result(msg["id"], {"pipeline_runs": list(pipeline_runs)})
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
|
||||
vol.Required("pipeline_id"): str,
|
||||
vol.Required("pipeline_run_id"): str,
|
||||
}
|
||||
)
|
||||
def websocket_get_run(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Get debug data for a pipeline run."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
pipeline_run_id = msg["pipeline_run_id"]
|
||||
|
||||
if pipeline_id not in pipeline_data.pipeline_runs:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"pipeline_id {pipeline_id} not found",
|
||||
)
|
||||
return
|
||||
|
||||
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||
|
||||
if pipeline_run_id not in pipeline_runs:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"pipeline_run_id {pipeline_run_id} not found",
|
||||
)
|
||||
return
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"events": pipeline_runs[pipeline_run_id]},
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from . import websocket_api
|
||||
from .const import (
|
||||
|
@ -24,7 +25,6 @@ from .const import (
|
|||
DEFAULT_STORED_TRACES,
|
||||
)
|
||||
from .models import ActionTrace, BaseTrace, RestoredTrace
|
||||
from .utils import LimitedSizeDict
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -72,6 +72,79 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.2
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.3
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.4
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en-US',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
|
|
|
@ -5,6 +5,7 @@ from homeassistant.components.assist_pipeline.const import DOMAIN
|
|||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
STORAGE_KEY,
|
||||
STORAGE_VERSION,
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -42,7 +43,8 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
|||
]
|
||||
pipeline_ids = []
|
||||
|
||||
store1: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store1 = pipeline_data.pipeline_store
|
||||
for pipeline in pipelines:
|
||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||
assert len(store1.data) == 3
|
||||
|
@ -103,6 +105,7 @@ async def test_loading_datasets_from_storage(
|
|||
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 3
|
||||
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
|
|
|
@ -5,10 +5,7 @@ from unittest.mock import ANY, MagicMock, patch
|
|||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
Pipeline,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
@ -21,6 +18,7 @@ async def test_text_only_pipeline(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
|
@ -40,20 +38,39 @@ async def test_text_only_pipeline(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_audio_pipeline(
|
||||
|
@ -63,6 +80,7 @@ async def test_audio_pipeline(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with audio input/output."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
|
@ -84,11 +102,13 @@ async def test_audio_pipeline(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(bytes([1]))
|
||||
|
@ -96,29 +116,50 @@ async def test_audio_pipeline(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# text to speech
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_intent_timeout(
|
||||
|
@ -128,6 +169,7 @@ async def test_intent_timeout(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test partial pipeline run with conversation agent timeout."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
async def sleepy_converse(*args, **kwargs):
|
||||
|
@ -155,16 +197,34 @@ async def test_intent_timeout(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_text_pipeline_timeout(
|
||||
|
@ -174,6 +234,7 @@ async def test_text_pipeline_timeout(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test text-only pipeline run with immediate timeout."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
async def sleepy_run(*args, **kwargs):
|
||||
|
@ -201,6 +262,22 @@ async def test_text_pipeline_timeout(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_intent_failed(
|
||||
|
@ -210,6 +287,7 @@ async def test_intent_failed(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test text-only pipeline run with conversation agent error."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
|
@ -233,16 +311,34 @@ async def test_intent_failed(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "intent-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_audio_pipeline_timeout(
|
||||
|
@ -252,6 +348,7 @@ async def test_audio_pipeline_timeout(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test audio pipeline run with immediate timeout."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
async def sleepy_run(*args, **kwargs):
|
||||
|
@ -281,6 +378,22 @@ async def test_audio_pipeline_timeout(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "timeout"
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_stt_provider_missing(
|
||||
|
@ -320,12 +433,13 @@ async def test_stt_stream_failed(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
|
||||
new=MagicMock(side_effect=RuntimeError),
|
||||
):
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
|
@ -345,11 +459,13 @@ async def test_stt_stream_failed(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(b"1")
|
||||
|
@ -358,6 +474,22 @@ async def test_stt_stream_failed(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_tts_failed(
|
||||
|
@ -367,15 +499,15 @@ async def test_tts_failed(
|
|||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test pipeline run with text to speech error."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
new=MagicMock(return_value=RuntimeError),
|
||||
):
|
||||
await client.send_json(
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
|
@ -391,16 +523,34 @@ async def test_tts_failed(
|
|||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# tts start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# tts error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "tts-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_invalid_stage_order(
|
||||
|
@ -428,7 +578,8 @@ async def test_add_pipeline(
|
|||
) -> None:
|
||||
"""Test we can add a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
@ -468,7 +619,8 @@ async def test_delete_pipeline(
|
|||
) -> None:
|
||||
"""Test we can delete a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
@ -542,7 +694,8 @@ async def test_list_pipelines(
|
|||
) -> None:
|
||||
"""Test we can list pipelines."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
|
@ -586,7 +739,8 @@ async def test_update_pipeline(
|
|||
) -> None:
|
||||
"""Test we can list pipelines."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
@ -660,7 +814,8 @@ async def test_set_preferred_pipeline(
|
|||
) -> None:
|
||||
"""Test updating the preferred pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
@ -715,3 +870,202 @@ async def test_set_preferred_pipeline_wrong_id(
|
|||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["error"]["code"] == "not_found"
|
||||
|
||||
|
||||
async def test_audio_pipeline_debug(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test debug listing events from a pipeline run with audio input/output."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 44100,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(bytes([1]))
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# text to speech
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
events.append(msg["event"])
|
||||
|
||||
# Get the id of the pipeline
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(msg["result"]["pipelines"]) == 1
|
||||
|
||||
pipeline_id = msg["result"]["pipelines"][0]["id"]
|
||||
|
||||
# Get the id for the run
|
||||
await client.send_json_auto_id(
|
||||
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"pipeline_runs": [ANY]}
|
||||
|
||||
pipeline_run_id = msg["result"]["pipeline_runs"][0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_pipeline_debug_list_runs_wrong_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
) -> None:
|
||||
"""Test debug listing events from a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": "blah"}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"pipeline_runs": []}
|
||||
|
||||
|
||||
async def test_pipeline_debug_get_run_wrong_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
) -> None:
|
||||
"""Test debug listing events from a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": "blah",
|
||||
"pipeline_run_id": "blah",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "pipeline_id blah not found",
|
||||
}
|
||||
|
||||
|
||||
async def test_pipeline_debug_get_run_wrong_pipeline_run(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
) -> None:
|
||||
"""Test debug listing events from a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "intent",
|
||||
"end_stage": "intent",
|
||||
"input": {"text": "Are the lights on?"},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# consume events
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
|
||||
# Get the id of the pipeline
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(msg["result"]["pipelines"]) == 1
|
||||
pipeline_id = msg["result"]["pipelines"][0]["id"]
|
||||
|
||||
# get debug data for the wrong run
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": "blah",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "pipeline_run_id blah not found",
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue