From 10606c4d1e896b347dc9f03f4fb3248dd3b97a39 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 18 Apr 2023 17:35:33 +0200 Subject: [PATCH] Use the preferred assist pipeline if none was specified (#91611) * Use the preferred assist pipeline if none was specified * Add test --- .../components/assist_pipeline/pipeline.py | 31 +++++++++++-------- .../assist_pipeline/test_pipeline.py | 22 +++++++++++++ 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 30cf52ea69c..380f4da8438 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -58,20 +58,25 @@ async def async_get_pipeline( """Get a pipeline by id or create one for a language.""" pipeline_data: PipelineData = hass.data[DOMAIN] - if pipeline_id is not None: - return pipeline_data.pipeline_store.data.get(pipeline_id) + if pipeline_id is None: + # A pipeline was not specified, use the preferred one + pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() - # Construct a pipeline for the required/configured language - language = language or hass.config.language - return await pipeline_data.pipeline_store.async_create_item( - { - "name": language, - "language": language, - "stt_engine": None, # first engine - "conversation_engine": None, # first agent - "tts_engine": None, # first engine - } - ) + if pipeline_id is None: + # There's no preferred pipeline, construct a pipeline for the + # required/configured language + language = language or hass.config.language + return await pipeline_data.pipeline_store.async_create_item( + { + "name": language, + "language": language, + "stt_engine": None, # first engine + "conversation_engine": None, # first agent + "tts_engine": None, # first engine + } + ) + + return pipeline_data.pipeline_store.data.get(pipeline_id) class PipelineEventType(StrEnum): diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index f84fb2fa1d1..1517e5f53a3 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -7,6 +7,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_VERSION, PipelineData, PipelineStorageCollection, + async_get_pipeline, ) from homeassistant.core import HomeAssistant from homeassistant.helpers.storage import Store @@ -109,3 +110,24 @@ async def test_loading_datasets_from_storage( store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" + + +async def test_get_pipeline(hass: HomeAssistant) -> None: + """Test async_get_pipeline.""" + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 0 + + # Test a pipeline is created + pipeline = await async_get_pipeline(hass, None) + assert len(store.data) == 1 + + # Test we get the same pipeline again + assert pipeline is await async_get_pipeline(hass, None) + assert len(store.data) == 1 + + # Test getting a specific pipeline + assert pipeline is await async_get_pipeline(hass, pipeline.id) + assert len(store.data) == 1