diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 6d00f26ee15..7f6bef6e3c0 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -31,6 +31,7 @@ from .pipeline import ( async_get_pipeline, async_get_pipelines, async_setup_pipeline_store, + async_update_pipeline, ) from .websocket_api import async_register_websocket_api @@ -40,6 +41,7 @@ __all__ = ( "async_get_pipelines", "async_setup", "async_pipeline_from_audio_stream", + "async_update_pipeline", "AudioSettings", "Pipeline", "PipelineEvent", diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 2ee1c71ccb8..71136dcdecb 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -43,6 +43,7 @@ from homeassistant.helpers.collection import ( ) from homeassistant.helpers.singleton import singleton from homeassistant.helpers.storage import Store +from homeassistant.helpers.typing import UNDEFINED, UndefinedType from homeassistant.util import ( dt as dt_util, language as language_util, @@ -276,6 +277,48 @@ def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]: return pipeline_data.pipeline_store.data.values() +async def async_update_pipeline( + hass: HomeAssistant, + pipeline: Pipeline, + *, + conversation_engine: str | UndefinedType = UNDEFINED, + conversation_language: str | UndefinedType = UNDEFINED, + language: str | UndefinedType = UNDEFINED, + name: str | UndefinedType = UNDEFINED, + stt_engine: str | None | UndefinedType = UNDEFINED, + stt_language: str | None | UndefinedType = UNDEFINED, + tts_engine: str | None | UndefinedType = UNDEFINED, + tts_language: str | None | UndefinedType = UNDEFINED, + tts_voice: str | None | UndefinedType = UNDEFINED, + wake_word_entity: str | None | UndefinedType = UNDEFINED, + wake_word_id: str | None | UndefinedType = UNDEFINED, +) -> None: + """Update a pipeline.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + + updates: dict[str, Any] = pipeline.to_json() + updates.pop("id") + # Refactor this once we bump to Python 3.12 + # and have https://peps.python.org/pep-0692/ + for key, val in ( + ("conversation_engine", conversation_engine), + ("conversation_language", conversation_language), + ("language", language), + ("name", name), + ("stt_engine", stt_engine), + ("stt_language", stt_language), + ("tts_engine", tts_engine), + ("tts_language", tts_language), + ("tts_voice", tts_voice), + ("wake_word_entity", wake_word_entity), + ("wake_word_id", wake_word_id), + ): + if val is not UNDEFINED: + updates[key] = val + + await pipeline_data.pipeline_store.async_update_item(pipeline.id, updates) + + class PipelineEventType(StrEnum): """Event types emitted during a pipeline run.""" diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index 4dc242376d9..bf60ab9cc94 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -10,6 +10,7 @@ from hass_nabucasa import Cloud import voluptuous as vol from homeassistant.components import alexa, google_assistant +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_DESCRIPTION, CONF_MODE, @@ -51,6 +52,7 @@ from .const import ( CONF_SERVICEHANDLERS_SERVER, CONF_THINGTALK_SERVER, CONF_USER_POOL_ID, + DATA_PLATFORMS_SETUP, DOMAIN, MODE_DEV, MODE_PROD, @@ -61,6 +63,8 @@ from .subscription import async_subscription_info DEFAULT_MODE = MODE_PROD +PLATFORMS = [Platform.STT] + SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect" @@ -262,6 +266,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async_manage_legacy_subscription_issue(hass, subscription_info) loaded = False + stt_platform_loaded = asyncio.Event() + tts_platform_loaded = asyncio.Event() + hass.data[DATA_PLATFORMS_SETUP] = { + Platform.STT: stt_platform_loaded, + Platform.TTS: tts_platform_loaded, + } async def _on_start() -> None: """Discover platforms.""" @@ -272,15 +282,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return loaded = True - stt_platform_loaded = asyncio.Event() - tts_platform_loaded = asyncio.Event() - stt_info = {"platform_loaded": stt_platform_loaded} tts_info = {"platform_loaded": tts_platform_loaded} await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) - await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config) await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config) - await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait()) + await tts_platform_loaded.wait() + + # The config entry should be loaded after the legacy tts platform is loaded + # to make sure that the tts integration is setup before we try to migrate + # old assist pipelines in the cloud stt entity. + await hass.config_entries.flow.async_init(DOMAIN, context={"source": "system"}) async def _on_connect() -> None: """Handle cloud connect.""" @@ -304,7 +315,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: cloud.register_on_initialized(_on_initialized) await cloud.initialize() - await http_api.async_setup(hass) + http_api.async_setup(hass) account_link.async_setup(hass) @@ -340,3 +351,19 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None: await cloud.remote.disconnect() cloud.client.prefs.async_listen_updates(remote_prefs_updated) + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT] + stt_platform_loaded.set() + + return True + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + + return unload_ok diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py index 8054b3bd953..31e990cdb81 100644 --- a/homeassistant/components/cloud/assist_pipeline.py +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -1,31 +1,48 @@ """Handle Cloud assist pipelines.""" +import asyncio + from homeassistant.components.assist_pipeline import ( async_create_default_pipeline, async_get_pipelines, async_setup_pipeline_store, + async_update_pipeline, ) from homeassistant.components.conversation import HOME_ASSISTANT_AGENT +from homeassistant.components.stt import DOMAIN as STT_DOMAIN +from homeassistant.const import Platform from homeassistant.core import HomeAssistant +import homeassistant.helpers.entity_registry as er -from .const import DOMAIN +from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: """Create a cloud assist pipeline.""" + # Wait for stt and tts platforms to set up before creating the pipeline. + platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP] + await asyncio.gather(*(event.wait() for event in platforms_setup.values())) # Make sure the pipeline store is loaded, needed because assist_pipeline # is an after dependency of cloud await async_setup_pipeline_store(hass) + entity_registry = er.async_get(hass) + new_stt_engine_id = entity_registry.async_get_entity_id( + STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID + ) + if new_stt_engine_id is None: + # If there's no cloud stt entity, we can't create a cloud pipeline. + return None + def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: """Return the ID of a cloud-enabled assist pipeline or None. - Check if a cloud pipeline already exists with - legacy cloud engine id. + Check if a cloud pipeline already exists with either + legacy or current cloud engine ids. """ for pipeline in async_get_pipelines(hass): if ( pipeline.conversation_engine == HOME_ASSISTANT_AGENT - and pipeline.stt_engine == DOMAIN + and pipeline.stt_engine in (DOMAIN, new_stt_engine_id) and pipeline.tts_engine == DOMAIN ): return pipeline.id @@ -34,7 +51,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: if (cloud_assist_pipeline(hass)) is not None or ( cloud_pipeline := await async_create_default_pipeline( hass, - stt_engine_id=DOMAIN, + stt_engine_id=new_stt_engine_id, tts_engine_id=DOMAIN, pipeline_name="Home Assistant Cloud", ) @@ -42,3 +59,27 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: return None return cloud_pipeline.id + + +async def async_migrate_cloud_pipeline_stt_engine( + hass: HomeAssistant, stt_engine_id: str +) -> None: + """Migrate the speech-to-text engine in the cloud assist pipeline.""" + # Migrate existing pipelines with cloud stt to use new cloud stt engine id. + # Added in 2024.01.0. Can be removed in 2025.01.0. + + # We need to make sure that tts is loaded before this migration. + # Assist pipeline will call default engine of tts when setting up the store. + # Wait for the tts platform loaded event here. + platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP] + await platforms_setup[Platform.TTS].wait() + + # Make sure the pipeline store is loaded, needed because assist_pipeline + # is an after dependency of cloud + await async_setup_pipeline_store(hass) + + pipelines = async_get_pipelines(hass) + for pipeline in pipelines: + if pipeline.stt_engine != DOMAIN: + continue + await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id) diff --git a/homeassistant/components/cloud/config_flow.py b/homeassistant/components/cloud/config_flow.py new file mode 100644 index 00000000000..a9554d97294 --- /dev/null +++ b/homeassistant/components/cloud/config_flow.py @@ -0,0 +1,23 @@ +"""Config flow for the Cloud integration.""" +from __future__ import annotations + +from typing import Any + +from homeassistant.config_entries import ConfigFlow +from homeassistant.data_entry_flow import FlowResult + +from .const import DOMAIN + + +class CloudConfigFlow(ConfigFlow, domain=DOMAIN): + """Handle a config flow for the Cloud integration.""" + + VERSION = 1 + + async def async_step_system( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the system step.""" + if self._async_current_entries(): + return self.async_abort(reason="single_instance_allowed") + return self.async_create_entry(title="Home Assistant Cloud", data={}) diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index 6e20978ec8d..db964607923 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -1,5 +1,6 @@ """Constants for the cloud component.""" DOMAIN = "cloud" +DATA_PLATFORMS_SETUP = "cloud_platforms_setup" REQUEST_TIMEOUT = 10 PREF_ENABLE_ALEXA = "alexa_enabled" @@ -64,3 +65,5 @@ MODE_DEV = "development" MODE_PROD = "production" DISPATCHER_REMOTE_UPDATE = "cloud_remote_update" + +STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text" diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index d01b0c29e06..849a1c99db9 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -28,7 +28,7 @@ from homeassistant.components.homeassistant import exposed_entities from homeassistant.components.http import HomeAssistantView, require_admin from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.util.location import async_detect_location_info @@ -66,7 +66,8 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = { } -async def async_setup(hass: HomeAssistant) -> None: +@callback +def async_setup(hass: HomeAssistant) -> None: """Initialize the HTTP API.""" websocket_api.async_register_command(hass, websocket_cloud_status) websocket_api.async_register_command(hass, websocket_subscription) diff --git a/homeassistant/components/cloud/strings.json b/homeassistant/components/cloud/strings.json index 8195b78a01e..56fb3c0f5c9 100644 --- a/homeassistant/components/cloud/strings.json +++ b/homeassistant/components/cloud/strings.json @@ -1,4 +1,10 @@ { + "config": { + "step": {}, + "abort": { + "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]" + } + }, "system_health": { "info": { "can_reach_cert_server": "Reach Certificate Server", diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index 7b6da8b7403..b652a36fa8a 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -13,37 +13,38 @@ from homeassistant.components.stt import ( AudioCodecs, AudioFormats, AudioSampleRates, - Provider, SpeechMetadata, SpeechResult, SpeechResultState, + SpeechToTextEntity, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine from .client import CloudClient -from .const import DOMAIN +from .const import DOMAIN, STT_ENTITY_UNIQUE_ID _LOGGER = logging.getLogger(__name__) -async def async_get_engine( +async def async_setup_entry( hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, -) -> CloudProvider: - """Set up Cloud speech component.""" + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Home Assistant Cloud speech platform via config entry.""" cloud: Cloud[CloudClient] = hass.data[DOMAIN] - - cloud_provider = CloudProvider(cloud) - if discovery_info is not None: - discovery_info["platform_loaded"].set() - return cloud_provider + async_add_entities([CloudProviderEntity(cloud)]) -class CloudProvider(Provider): +class CloudProviderEntity(SpeechToTextEntity): """NabuCasa speech API provider.""" + _attr_name = "Home Assistant Cloud" + _attr_unique_id = STT_ENTITY_UNIQUE_ID + def __init__(self, cloud: Cloud[CloudClient]) -> None: """Home Assistant NabuCasa Speech to text.""" self.cloud = cloud @@ -78,6 +79,10 @@ class CloudProvider(Provider): """Return a list of supported channels.""" return [AudioChannels.CHANNEL_MONO] + async def async_added_to_hass(self) -> None: + """Run when entity is about to be added to hass.""" + await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id) + async def async_process_audio_stream( self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] ) -> SpeechResult: diff --git a/homeassistant/components/stt/legacy.py b/homeassistant/components/stt/legacy.py index cd5aef312ce..45f8ccefc68 100644 --- a/homeassistant/components/stt/legacy.py +++ b/homeassistant/components/stt/legacy.py @@ -29,9 +29,6 @@ _LOGGER = logging.getLogger(__name__) @callback def async_default_provider(hass: HomeAssistant) -> str | None: """Return the domain of the default provider.""" - if "cloud" in hass.data[DATA_PROVIDERS]: - return "cloud" - return next(iter(hass.data[DATA_PROVIDERS]), None) diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 597d355806f..35913df7400 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1,4 +1,5 @@ """Websocket tests for Voice Assistant integration.""" +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import ANY, patch @@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( async_create_default_pipeline, async_get_pipeline, async_get_pipelines, + async_update_pipeline, ) from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -26,6 +28,13 @@ from .conftest import MockSttProvider, MockTTSProvider from tests.common import flush_store +@pytest.fixture(autouse=True) +async def delay_save_fixture() -> AsyncGenerator[None, None]: + """Load the homeassistant integration.""" + with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0): + yield + + @pytest.fixture(autouse=True) async def load_homeassistant(hass) -> None: """Load the homeassistant integration.""" @@ -478,3 +487,125 @@ async def test_default_pipeline_unsupported_tts_language( wake_word_entity=None, wake_word_id=None, ) + + +async def test_update_pipeline( + hass: HomeAssistant, + hass_storage: dict[str, Any], +) -> None: + """Test async_update_pipeline.""" + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipelines = async_get_pipelines(hass) + pipelines = list(pipelines) + assert pipelines == [ + Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=ANY, + language="en", + name="Home Assistant", + stt_engine=None, + stt_language=None, + tts_engine=None, + tts_language=None, + tts_voice=None, + wake_word_entity=None, + wake_word_id=None, + ) + ] + + pipeline = pipelines[0] + await async_update_pipeline( + hass, + pipeline, + conversation_engine="homeassistant_1", + conversation_language="de", + language="de", + name="Home Assistant 1", + stt_engine="stt.test_1", + stt_language="de", + tts_engine="test_1", + tts_language="de", + tts_voice="test_voice", + wake_word_entity="wake_work.test_1", + wake_word_id="wake_word_id_1", + ) + + pipelines = async_get_pipelines(hass) + pipelines = list(pipelines) + pipeline = pipelines[0] + assert pipelines == [ + Pipeline( + conversation_engine="homeassistant_1", + conversation_language="de", + id=pipeline.id, + language="de", + name="Home Assistant 1", + stt_engine="stt.test_1", + stt_language="de", + tts_engine="test_1", + tts_language="de", + tts_voice="test_voice", + wake_word_entity="wake_work.test_1", + wake_word_id="wake_word_id_1", + ) + ] + assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 + assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { + "conversation_engine": "homeassistant_1", + "conversation_language": "de", + "id": pipeline.id, + "language": "de", + "name": "Home Assistant 1", + "stt_engine": "stt.test_1", + "stt_language": "de", + "tts_engine": "test_1", + "tts_language": "de", + "tts_voice": "test_voice", + "wake_word_entity": "wake_work.test_1", + "wake_word_id": "wake_word_id_1", + } + + await async_update_pipeline( + hass, + pipeline, + stt_engine="stt.test_2", + stt_language="en", + tts_engine="test_2", + tts_language="en", + ) + + pipelines = async_get_pipelines(hass) + pipelines = list(pipelines) + assert pipelines == [ + Pipeline( + conversation_engine="homeassistant_1", + conversation_language="de", + id=pipeline.id, + language="de", + name="Home Assistant 1", + stt_engine="stt.test_2", + stt_language="en", + tts_engine="test_2", + tts_language="en", + tts_voice="test_voice", + wake_word_entity="wake_work.test_1", + wake_word_id="wake_word_id_1", + ) + ] + assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1 + assert hass_storage[STORAGE_KEY]["data"]["items"][0] == { + "conversation_engine": "homeassistant_1", + "conversation_language": "de", + "id": pipeline.id, + "language": "de", + "name": "Home Assistant 1", + "stt_engine": "stt.test_2", + "stt_language": "en", + "tts_engine": "test_2", + "tts_language": "en", + "tts_voice": "test_voice", + "wake_word_entity": "wake_work.test_1", + "wake_word_id": "wake_word_id_1", + } diff --git a/tests/components/cloud/__init__.py b/tests/components/cloud/__init__.py index ea8c09706c5..22b84f032f6 100644 --- a/tests/components/cloud/__init__.py +++ b/tests/components/cloud/__init__.py @@ -1,7 +1,8 @@ """Tests for the cloud component.""" - from unittest.mock import AsyncMock, patch +from hass_nabucasa import Cloud + from homeassistant.components import cloud from homeassistant.components.cloud import const, prefs as cloud_prefs from homeassistant.setup import async_setup_component @@ -14,7 +15,7 @@ async def mock_cloud(hass, config=None): assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, cloud.DOMAIN, {"cloud": config or {}}) - cloud_inst = hass.data["cloud"] + cloud_inst: Cloud = hass.data["cloud"] with patch("hass_nabucasa.Cloud.run_executor", AsyncMock(return_value=None)): await cloud_inst.initialize() diff --git a/tests/components/cloud/test_config_flow.py b/tests/components/cloud/test_config_flow.py new file mode 100644 index 00000000000..ee4e37276dc --- /dev/null +++ b/tests/components/cloud/test_config_flow.py @@ -0,0 +1,40 @@ +"""Test the Home Assistant Cloud config flow.""" +from unittest.mock import patch + +from homeassistant.components.cloud.const import DOMAIN +from homeassistant.core import HomeAssistant + +from tests.common import MockConfigEntry + + +async def test_config_flow(hass: HomeAssistant) -> None: + """Test create cloud entry.""" + + with patch( + "homeassistant.components.cloud.async_setup", return_value=True + ) as mock_setup, patch( + "homeassistant.components.cloud.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "system"} + ) + assert result["type"] == "create_entry" + assert result["title"] == "Home Assistant Cloud" + assert result["data"] == {} + await hass.async_block_till_done() + + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_multiple_entries(hass: HomeAssistant) -> None: + """Test creating multiple cloud entries.""" + config_entry = MockConfigEntry(domain=DOMAIN) + config_entry.add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "system"} + ) + assert result["type"] == "abort" + assert result["reason"] == "single_instance_allowed" diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 3d7e6a69e3c..29930632691 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -46,6 +46,26 @@ PIPELINE_DATA_LEGACY = { "preferred_item": "12345", } +PIPELINE_DATA = { + "items": [ + { + "conversation_engine": "homeassistant", + "conversation_language": "language_1", + "id": "12345", + "language": "language_1", + "name": "Home Assistant Cloud", + "stt_engine": "stt.home_assistant_cloud", + "stt_language": "language_1", + "tts_engine": "cloud", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + }, + ], + "preferred_item": "12345", +} + PIPELINE_DATA_OTHER = { "items": [ { @@ -127,7 +147,34 @@ async def test_google_actions_sync_fails( assert mock_request_sync.call_count == 1 -@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA_LEGACY]) +async def test_login_view_missing_stt_entity( + hass: HomeAssistant, + setup_cloud: None, + entity_registry: er.EntityRegistry, + hass_client: ClientSessionGenerator, +) -> None: + """Test logging in when the cloud stt entity is missing.""" + # Make sure that the cloud stt entity does not exist. + entity_registry.async_remove("stt.home_assistant_cloud") + await hass.async_block_till_done() + + cloud_client = await hass_client() + + # We assume the user needs to login again for some reason. + with patch( + "homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline", + ) as create_pipeline_mock: + req = await cloud_client.post( + "/api/cloud/login", json={"email": "my_username", "password": "my_password"} + ) + + assert req.status == HTTPStatus.OK + result = await req.json() + assert result == {"success": True, "cloud_pipeline": None} + create_pipeline_mock.assert_not_awaited() + + +@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA, PIPELINE_DATA_LEGACY]) async def test_login_view_existing_pipeline( hass: HomeAssistant, cloud: MagicMock, @@ -195,7 +242,7 @@ async def test_login_view_create_pipeline( assert result == {"success": True, "cloud_pipeline": "12345"} create_pipeline_mock.assert_awaited_once_with( hass, - stt_engine_id="cloud", + stt_engine_id="stt.home_assistant_cloud", tts_engine_id="cloud", pipeline_name="Home Assistant Cloud", ) @@ -234,7 +281,7 @@ async def test_login_view_create_pipeline_fail( assert result == {"success": True, "cloud_pipeline": None} create_pipeline_mock.assert_awaited_once_with( hass, - stt_engine_id="cloud", + stt_engine_id="stt.home_assistant_cloud", tts_engine_id="cloud", pipeline_name="Home Assistant Cloud", ) diff --git a/tests/components/cloud/test_stt.py b/tests/components/cloud/test_stt.py new file mode 100644 index 00000000000..666d8ae7d65 --- /dev/null +++ b/tests/components/cloud/test_stt.py @@ -0,0 +1,201 @@ +"""Test the speech-to-text platform for the cloud integration.""" +from collections.abc import AsyncGenerator +from copy import deepcopy +from http import HTTPStatus +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +from hass_nabucasa.voice import STTResponse, VoiceError +import pytest + +from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY +from homeassistant.components.cloud import DOMAIN +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + +from tests.typing import ClientSessionGenerator + +PIPELINE_DATA = { + "items": [ + { + "conversation_engine": "conversation_engine_1", + "conversation_language": "language_1", + "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "language": "language_1", + "name": "Home Assistant Cloud", + "stt_engine": "cloud", + "stt_language": "language_1", + "tts_engine": "cloud", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + }, + { + "conversation_engine": "conversation_engine_2", + "conversation_language": "language_2", + "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", + "language": "language_2", + "name": "name_2", + "stt_engine": "stt_engine_2", + "stt_language": "language_2", + "tts_engine": "tts_engine_2", + "tts_language": "language_2", + "tts_voice": "The Voice", + "wake_word_entity": None, + "wake_word_id": None, + }, + { + "conversation_engine": "conversation_engine_3", + "conversation_language": "language_3", + "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", + "language": "language_3", + "name": "name_3", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, + }, + ], + "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", +} + + +@pytest.fixture(autouse=True) +async def load_homeassistant(hass: HomeAssistant) -> None: + """Load the homeassistant integration.""" + assert await async_setup_component(hass, "homeassistant", {}) + + +@pytest.fixture(autouse=True) +async def delay_save_fixture() -> AsyncGenerator[None, None]: + """Load the homeassistant integration.""" + with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0): + yield + + +@pytest.mark.parametrize( + ("mock_process_stt", "expected_response_data"), + [ + ( + AsyncMock(return_value=STTResponse(True, "Turn the Kitchen Lights on")), + {"text": "Turn the Kitchen Lights on", "result": "success"}, + ), + (AsyncMock(side_effect=VoiceError("Boom!")), {"text": None, "result": "error"}), + ], +) +async def test_cloud_speech( + hass: HomeAssistant, + cloud: MagicMock, + hass_client: ClientSessionGenerator, + mock_process_stt: AsyncMock, + expected_response_data: dict[str, Any], +) -> None: + """Test cloud text-to-speech.""" + cloud.voice.process_stt = mock_process_stt + + assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) + await hass.async_block_till_done() + + on_start_callback = cloud.register_on_start.call_args[0][0] + await on_start_callback() + + state = hass.states.get("stt.home_assistant_cloud") + assert state + assert state.state == STATE_UNKNOWN + + client = await hass_client() + + response = await client.post( + "/api/stt/stt.home_assistant_cloud", + headers={ + "X-Speech-Content": ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" + " language=de-DE" + ) + }, + data=b"Test", + ) + response_data = await response.json() + + assert mock_process_stt.call_count == 1 + assert ( + mock_process_stt.call_args.kwargs["content_type"] + == "audio/wav; codecs=audio/pcm; samplerate=16000" + ) + assert mock_process_stt.call_args.kwargs["language"] == "de-DE" + assert response.status == HTTPStatus.OK + assert response_data == expected_response_data + + state = hass.states.get("stt.home_assistant_cloud") + assert state + assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + + +async def test_migrating_pipelines( + hass: HomeAssistant, + cloud: MagicMock, + hass_client: ClientSessionGenerator, + hass_storage: dict[str, Any], +) -> None: + """Test migrating pipelines when cloud stt entity is added.""" + cloud.voice.process_stt = AsyncMock( + return_value=STTResponse(True, "Turn the Kitchen Lights on") + ) + hass_storage[STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": "assist_pipeline.pipelines", + "data": deepcopy(PIPELINE_DATA), + } + + assert await async_setup_component(hass, "assist_pipeline", {}) + assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) + await hass.async_block_till_done() + + on_start_callback = cloud.register_on_start.call_args[0][0] + await on_start_callback() + await hass.async_block_till_done() + + state = hass.states.get("stt.home_assistant_cloud") + assert state + assert state.state == STATE_UNKNOWN + + # The stt engine should be updated to the new cloud stt engine id. + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] + == "stt.home_assistant_cloud" + ) + + # The other items should stay the same. + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"] + == "conversation_engine_1" + ) + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"] + == "language_1" + ) + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["id"] + == "01GX8ZWBAQYWNB1XV3EXEZ75DY" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1" + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1" + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud" + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1" + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"] + == "Arnold Schwarzenegger" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None + assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1] + assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2] diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 4100df94b9e..9764451c5d5 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -121,12 +121,20 @@ class STTFlow(ConfigFlow): """Test flow.""" -@pytest.fixture(autouse=True) -def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: - """Mock config flow.""" - mock_platform(hass, f"{TEST_DOMAIN}.config_flow") +@pytest.fixture(name="config_flow_test_domain") +def config_flow_test_domain_fixture() -> str: + """Test domain fixture.""" + return TEST_DOMAIN - with mock_config_flow(TEST_DOMAIN, STTFlow): + +@pytest.fixture(autouse=True) +def config_flow_fixture( + hass: HomeAssistant, config_flow_test_domain: str +) -> Generator[None, None, None]: + """Mock config flow.""" + mock_platform(hass, f"{config_flow_test_domain}.config_flow") + + with mock_config_flow(config_flow_test_domain, STTFlow): yield @@ -137,6 +145,7 @@ async def setup_fixture( request: pytest.FixtureRequest, ) -> MockProvider | MockProviderEntity: """Set up the test environment.""" + provider: MockProvider | MockProviderEntity if request.param == "mock_setup": provider = MockProvider() await mock_setup(hass, tmp_path, provider) @@ -166,7 +175,10 @@ async def mock_setup( async def mock_config_entry_setup( - hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity + hass: HomeAssistant, + tmp_path: Path, + mock_provider_entity: MockProviderEntity, + test_domain: str = TEST_DOMAIN, ) -> MockConfigEntry: """Set up a test provider via config entry.""" @@ -187,7 +199,7 @@ async def mock_config_entry_setup( mock_integration( hass, MockModule( - TEST_DOMAIN, + test_domain, async_setup_entry=async_setup_entry_init, async_unload_entry=async_unload_entry_init, ), @@ -201,9 +213,9 @@ async def mock_config_entry_setup( """Set up test stt platform via config entry.""" async_add_entities([mock_provider_entity]) - mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform) + mock_stt_entity_platform(hass, tmp_path, test_domain, async_setup_entry_platform) - config_entry = MockConfigEntry(domain=TEST_DOMAIN) + config_entry = MockConfigEntry(domain=test_domain) config_entry.add_to_hass(hass) assert await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() @@ -456,7 +468,11 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None: assert async_default_engine(hass) is None -async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None: +async def test_default_engine( + hass: HomeAssistant, + tmp_path: Path, + mock_provider: MockProvider, +) -> None: """Test async_default_engine.""" mock_stt_platform( hass, @@ -479,26 +495,31 @@ async def test_default_engine_entity( assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}" -async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None: +@pytest.mark.parametrize("config_flow_test_domain", ["new_test"]) +async def test_default_engine_prefer_provider( + hass: HomeAssistant, + tmp_path: Path, + mock_provider_entity: MockProviderEntity, + mock_provider: MockProvider, + config_flow_test_domain: str, +) -> None: """Test async_default_engine.""" - mock_stt_platform( - hass, - tmp_path, - TEST_DOMAIN, - async_get_engine=AsyncMock(return_value=mock_provider), - ) - mock_stt_platform( - hass, - tmp_path, - "cloud", - async_get_engine=AsyncMock(return_value=mock_provider), - ) - assert await async_setup_component( - hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]} + mock_provider_entity.url_path = "stt.new_test" + mock_provider_entity._attr_name = "New test" + + await mock_setup(hass, tmp_path, mock_provider) + await mock_config_entry_setup( + hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain ) await hass.async_block_till_done() - assert async_default_engine(hass) == "cloud" + entity_engine = async_get_speech_to_text_engine(hass, "stt.new_test") + assert entity_engine is not None + assert entity_engine.name == "New test" + provider_engine = async_get_speech_to_text_engine(hass, "test") + assert provider_engine is not None + assert provider_engine.name == "test" + assert async_default_engine(hass) == "test" async def test_get_engine_legacy(