Remove cloud details from assist pipeline (#105687)

* Remove cloud details from assist pipeline

* Update assist pipeline tests

* Update cloud tests
This commit is contained in:
Martin Hjelmare 2023-12-14 10:15:59 +01:00 committed by GitHub
parent 82f0b28e89
commit 2e448d2d13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 66 deletions

View file

@ -115,6 +115,7 @@ async def _async_resolve_default_pipeline_settings(
hass: HomeAssistant, hass: HomeAssistant,
stt_engine_id: str | None, stt_engine_id: str | None,
tts_engine_id: str | None, tts_engine_id: str | None,
pipeline_name: str,
) -> dict[str, str | None]: ) -> dict[str, str | None]:
"""Resolve settings for a default pipeline. """Resolve settings for a default pipeline.
@ -123,7 +124,6 @@ async def _async_resolve_default_pipeline_settings(
""" """
conversation_language = "en" conversation_language = "en"
pipeline_language = "en" pipeline_language = "en"
pipeline_name = "Home Assistant"
stt_engine = None stt_engine = None
stt_language = None stt_language = None
tts_engine = None tts_engine = None
@ -195,9 +195,6 @@ async def _async_resolve_default_pipeline_settings(
) )
tts_engine_id = None tts_engine_id = None
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud"
return { return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT, "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_language": conversation_language, "conversation_language": conversation_language,
@ -221,12 +218,17 @@ async def _async_create_default_pipeline(
The default pipeline will use the homeassistant conversation agent and the The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines. default stt / tts engines.
""" """
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None) pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
)
return await pipeline_store.async_create_item(pipeline_settings) return await pipeline_store.async_create_item(pipeline_settings)
async def async_create_default_pipeline( async def async_create_default_pipeline(
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str hass: HomeAssistant,
stt_engine_id: str,
tts_engine_id: str,
pipeline_name: str,
) -> Pipeline | None: ) -> Pipeline | None:
"""Create a pipeline with default settings. """Create a pipeline with default settings.
@ -236,7 +238,7 @@ async def async_create_default_pipeline(
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store pipeline_store = pipeline_data.pipeline_store
pipeline_settings = await _async_resolve_default_pipeline_settings( pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
) )
if ( if (
pipeline_settings["stt_engine"] != stt_engine_id pipeline_settings["stt_engine"] != stt_engine_id

View file

@ -232,7 +232,10 @@ class CloudLoginView(HomeAssistantView):
new_cloud_pipeline_id: str | None = None new_cloud_pipeline_id: str | None = None
if (cloud_assist_pipeline(hass)) is None: if (cloud_assist_pipeline(hass)) is None:
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline( if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
hass, DOMAIN, DOMAIN hass,
stt_engine_id=DOMAIN,
tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud",
): ):
new_cloud_pipeline_id = cloud_pipeline.id new_cloud_pipeline_id = cloud_pipeline.id
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id}) return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})

View file

@ -1,6 +1,6 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
from typing import Any from typing import Any
from unittest.mock import ANY, AsyncMock, patch from unittest.mock import ANY, patch
import pytest import pytest
@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES from . import MANY_LANGUAGES
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider from .conftest import MockSttProvider, MockTTSProvider
from tests.common import MockModule, flush_store, mock_integration, mock_platform from tests.common import flush_store
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -237,13 +237,26 @@ async def test_create_default_pipeline(
store = pipeline_data.pipeline_store store = pipeline_data.pipeline_store
assert len(store.data) == 1 assert len(store.data) == 1
assert await async_create_default_pipeline(hass, "bla", "bla") is None assert (
assert await async_create_default_pipeline(hass, "test", "test") == Pipeline( await async_create_default_pipeline(
hass,
stt_engine_id="bla",
tts_engine_id="bla",
pipeline_name="Bla pipeline",
)
is None
)
assert await async_create_default_pipeline(
hass,
stt_engine_id="test",
tts_engine_id="test",
pipeline_name="Test pipeline",
) == Pipeline(
conversation_engine="homeassistant", conversation_engine="homeassistant",
conversation_language="en", conversation_language="en",
id=ANY, id=ANY,
language="en", language="en",
name="Home Assistant", name="Test pipeline",
stt_engine="test", stt_engine="test",
stt_language="en-US", stt_language="en-US",
tts_engine="test", tts_engine="test",
@ -465,53 +478,3 @@ async def test_default_pipeline_unsupported_tts_language(
wake_word_entity=None, wake_word_entity=None,
wake_word_id=None, wake_word_id=None,
) )
async def test_default_pipeline_cloud(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
mock_integration(hass, MockModule("cloud"))
mock_platform(
hass,
"cloud.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform(
hass,
"cloud.stt",
MockSttPlatform(
async_get_engine=AsyncMock(return_value=mock_stt_provider),
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="james_earl_jones",
wake_word_entity=None,
wake_word_id=None,
)

View file

@ -193,7 +193,12 @@ async def test_login_view_create_pipeline(
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
result = await req.json() result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"} assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud") create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)
async def test_login_view_create_pipeline_fail( async def test_login_view_create_pipeline_fail(
@ -227,7 +232,12 @@ async def test_login_view_create_pipeline_fail(
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
result = await req.json() result = await req.json()
assert result == {"success": True, "cloud_pipeline": None} assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud") create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)
async def test_login_view_random_exception( async def test_login_view_random_exception(