Use the preferred assist pipeline if none was specified (#91611)

* Use the preferred assist pipeline if none was specified

* Add test
This commit is contained in:
Erik Montnemery 2023-04-18 17:35:33 +02:00 committed by GitHub
parent 016e051db6
commit 10606c4d1e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 13 deletions

View file

@ -58,20 +58,25 @@ async def async_get_pipeline(
"""Get a pipeline by id or create one for a language."""
pipeline_data: PipelineData = hass.data[DOMAIN]
if pipeline_id is not None:
return pipeline_data.pipeline_store.data.get(pipeline_id)
if pipeline_id is None:
# A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
# Construct a pipeline for the required/configured language
language = language or hass.config.language
return await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
"stt_engine": None, # first engine
"conversation_engine": None, # first agent
"tts_engine": None, # first engine
}
)
if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the
# required/configured language
language = language or hass.config.language
return await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
"stt_engine": None, # first engine
"conversation_engine": None, # first agent
"tts_engine": None, # first engine
}
)
return pipeline_data.pipeline_store.data.get(pipeline_id)
class PipelineEventType(StrEnum):

View file

@ -7,6 +7,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_VERSION,
PipelineData,
PipelineStorageCollection,
async_get_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
@ -109,3 +110,24 @@ async def test_loading_datasets_from_storage(
store = pipeline_data.pipeline_store
assert len(store.data) == 3
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
async def test_get_pipeline(hass: HomeAssistant) -> None:
"""Test async_get_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 0
# Test a pipeline is created
pipeline = await async_get_pipeline(hass, None)
assert len(store.data) == 1
# Test we get the same pipeline again
assert pipeline is await async_get_pipeline(hass, None)
assert len(store.data) == 1
# Test getting a specific pipeline
assert pipeline is await async_get_pipeline(hass, pipeline.id)
assert len(store.data) == 1