From ed737f306b9f3a6aba4195cf9dcc26a60f23d19c Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 26 Apr 2023 12:53:58 +0200 Subject: [PATCH] Remove cloud assist pipeline setup from cloud client (#92056) --- homeassistant/components/cloud/__init__.py | 41 ++++++------ homeassistant/components/cloud/client.py | 35 +---------- homeassistant/components/cloud/http_api.py | 25 ++++++-- homeassistant/components/cloud/stt.py | 3 +- homeassistant/components/cloud/tts.py | 3 +- tests/components/cloud/test_client.py | 47 ++------------ tests/components/cloud/test_http_api.py | 73 ++++++++++++++-------- tests/components/cloud/test_init.py | 16 ++++- tests/components/cloud/test_tts.py | 9 ++- 9 files changed, 119 insertions(+), 133 deletions(-) diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index ebfaf6b0baa..0af85fe9d4d 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -238,25 +238,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await prefs.async_initialize() # Initialize Cloud - loaded = False - - async def _discover_platforms(): - """Discover platforms.""" - nonlocal loaded - - # Prevent multiple discovery - if loaded: - return - loaded = True - - await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) - await async_load_platform(hass, Platform.STT, DOMAIN, {}, config) - await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config) - websession = async_get_clientsession(hass) - client = CloudClient( - hass, prefs, websession, alexa_conf, google_conf, _discover_platforms - ) + client = CloudClient(hass, prefs, websession, alexa_conf, google_conf) cloud = hass.data[DOMAIN] = Cloud(client, **kwargs) cloud.iot.register_on_connect(client.on_cloud_connected) @@ -288,6 +271,27 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if subscription_info := await async_subscription_info(cloud): async_manage_legacy_subscription_issue(hass, subscription_info) + loaded = False + + async def _on_start(): + """Discover platforms.""" + nonlocal loaded + + # Prevent multiple discovery + if loaded: + return + loaded = True + + stt_platform_loaded = asyncio.Event() + tts_platform_loaded = asyncio.Event() + stt_info = {"platform_loaded": stt_platform_loaded} + tts_info = {"platform_loaded": tts_platform_loaded} + + await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) + await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config) + await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config) + await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait()) + async def _on_connect(): """Handle cloud connect.""" async_dispatcher_send( @@ -304,6 +308,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Update preferences.""" await prefs.async_update(remote_domain=cloud.remote.instance_domain) + cloud.register_on_start(_on_start) cloud.iot.register_on_connect(_on_connect) cloud.iot.register_on_disconnect(_on_disconnect) cloud.register_on_initialized(_on_initialized) diff --git a/homeassistant/components/cloud/client.py b/homeassistant/components/cloud/client.py index 7a0fada7e15..631c0641b4f 100644 --- a/homeassistant/components/cloud/client.py +++ b/homeassistant/components/cloud/client.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine from http import HTTPStatus import logging from pathlib import Path @@ -11,13 +10,7 @@ from typing import Any import aiohttp from hass_nabucasa.client import CloudClient as Interface -from homeassistant.components import ( - assist_pipeline, - conversation, - google_assistant, - persistent_notification, - webhook, -) +from homeassistant.components import google_assistant, persistent_notification, webhook from homeassistant.components.alexa import ( errors as alexa_errors, smart_home as alexa_smart_home, @@ -43,7 +36,6 @@ class CloudClient(Interface): websession: aiohttp.ClientSession, alexa_user_config: dict[str, Any], google_user_config: dict[str, Any], - on_started_cb: Callable[[], Coroutine[Any, Any, None]], ) -> None: """Initialize client interface to Cloud.""" self._hass = hass @@ -56,10 +48,6 @@ class CloudClient(Interface): self._alexa_config_init_lock = asyncio.Lock() self._google_config_init_lock = asyncio.Lock() self._relayer_region: str | None = None - self._on_started_cb = on_started_cb - self.cloud_pipeline = self._cloud_assist_pipeline() - self.stt_platform_loaded = asyncio.Event() - self.tts_platform_loaded = asyncio.Event() @property def base_path(self) -> Path: @@ -148,22 +136,6 @@ class CloudClient(Interface): return self._google_config - def _cloud_assist_pipeline(self) -> str | None: - """Return the ID of a cloud-enabled assist pipeline or None.""" - for pipeline in assist_pipeline.async_get_pipelines(self._hass): - if ( - pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT - and pipeline.stt_engine == DOMAIN - and pipeline.tts_engine == DOMAIN - ): - return pipeline.id - return None - - async def create_cloud_assist_pipeline(self) -> None: - """Create a cloud-enabled assist pipeline.""" - await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN) - self.cloud_pipeline = self._cloud_assist_pipeline() - async def on_cloud_connected(self) -> None: """When cloud is connected.""" is_new_user = await self.prefs.async_set_username(self.cloud.username) @@ -211,11 +183,6 @@ class CloudClient(Interface): async def cloud_started(self) -> None: """When cloud is started.""" - await self._on_started_cb() - await asyncio.gather( - self.stt_platform_loaded.wait(), - self.tts_platform_loaded.wait(), - ) async def cloud_stopped(self) -> None: """When the cloud is stopped.""" diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 82eb64b3a3a..ab72af9fa5e 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -15,7 +15,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED from hass_nabucasa.voice import MAP_VOICE import voluptuous as vol -from homeassistant.components import websocket_api +from homeassistant.components import assist_pipeline, conversation, websocket_api from homeassistant.components.alexa import ( entities as alexa_entities, errors as alexa_errors, @@ -182,15 +182,28 @@ class CloudLoginView(HomeAssistantView): ) async def post(self, request, data): """Handle login request.""" + + def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: + """Return the ID of a cloud-enabled assist pipeline or None.""" + for pipeline in assist_pipeline.async_get_pipelines(hass): + if ( + pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT + and pipeline.stt_engine == DOMAIN + and pipeline.tts_engine == DOMAIN + ): + return pipeline.id + return None + hass = request.app["hass"] cloud = hass.data[DOMAIN] await cloud.login(data["email"], data["password"]) - if cloud.client.cloud_pipeline is None: - await cloud.client.create_cloud_assist_pipeline() - return self.json( - {"success": True, "cloud_pipeline": cloud.client.cloud_pipeline} - ) + if (cloud_pipeline_id := cloud_assist_pipeline(hass)) is None: + if cloud_pipeline := await assist_pipeline.async_create_default_pipeline( + hass, DOMAIN, DOMAIN + ): + cloud_pipeline_id = cloud_pipeline.id + return self.json({"success": True, "cloud_pipeline": cloud_pipeline_id}) class CloudLogoutView(HomeAssistantView): diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index 8ccb932c545..84e1e088d47 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -29,7 +29,8 @@ async def async_get_engine(hass, config, discovery_info=None): cloud: Cloud = hass.data[DOMAIN] cloud_provider = CloudProvider(cloud) - cloud.client.stt_platform_loaded.set() + if discovery_info is not None: + discovery_info["platform_loaded"].set() return cloud_provider diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 58e918b9679..fea2ffca987 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -64,7 +64,8 @@ async def async_get_engine(hass, config, discovery_info=None): gender = config[ATTR_GENDER] cloud_provider = CloudProvider(cloud, language, gender) - cloud.client.tts_platform_loaded.set() + if discovery_info is not None: + discovery_info["platform_loaded"].set() return cloud_provider diff --git a/tests/components/cloud/test_client.py b/tests/components/cloud/test_client.py index 0e053941f51..d1e1a8ce112 100644 --- a/tests/components/cloud/test_client.py +++ b/tests/components/cloud/test_client.py @@ -6,11 +6,6 @@ import aiohttp from aiohttp import web import pytest -from homeassistant.components.assist_pipeline import ( - Pipeline, - async_get_pipeline, - async_get_pipelines, -) from homeassistant.components.cloud import DOMAIN from homeassistant.components.cloud.client import CloudClient from homeassistant.components.cloud.const import ( @@ -303,18 +298,14 @@ async def test_google_config_should_2fa( assert not gconf.should_2fa(state) -@patch( - "homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines", - return_value=[], -) -async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None: +async def test_set_username(hass: HomeAssistant) -> None: """Test we set username during login.""" prefs = MagicMock( alexa_enabled=False, google_enabled=False, async_set_username=AsyncMock(return_value=None), ) - client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) + client = CloudClient(hass, prefs, None, {}, {}) client.cloud = MagicMock(is_logged_in=True, username="mock-username") await client.on_cloud_connected() @@ -322,12 +313,8 @@ async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None: assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username" -@patch( - "homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines", - return_value=[], -) async def test_login_recovers_bad_internet( - async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture + hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test Alexa can recover bad auth.""" prefs = Mock( @@ -335,7 +322,7 @@ async def test_login_recovers_bad_internet( google_enabled=False, async_set_username=AsyncMock(return_value=None), ) - client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) + client = CloudClient(hass, prefs, None, {}, {}) client.cloud = Mock() client._alexa_config = Mock( async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError) @@ -367,29 +354,3 @@ async def test_system_msg(hass: HomeAssistant) -> None: assert response is None assert cloud.client.relayer_region == "xx-earth-616" - - -async def test_create_cloud_assist_pipeline( - hass: HomeAssistant, mock_cloud_setup, mock_cloud_login -) -> None: - """Test creating a cloud enabled assist pipeline.""" - cloud_client: CloudClient = hass.data[DOMAIN].client - await cloud_client.cloud_started() - assert cloud_client.cloud_pipeline is None - assert len(async_get_pipelines(hass)) == 1 - - await cloud_client.create_cloud_assist_pipeline() - assert cloud_client.cloud_pipeline is not None - assert len(async_get_pipelines(hass)) == 2 - assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline( - conversation_engine="homeassistant", - conversation_language="en", - id=cloud_client.cloud_pipeline, - language="en", - name="Home Assistant Cloud", - stt_engine="cloud", - stt_language="en-US", - tts_engine="cloud", - tts_language="en-US", - tts_voice="JennyNeural", - ) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 6dfca339182..de5778b72e7 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -104,45 +104,68 @@ async def test_google_actions_sync_fails( async def test_login_view(hass: HomeAssistant, cloud_client) -> None: - """Test logging in.""" - create_cloud_assist_pipeline_mock = AsyncMock() - hass.data["cloud"] = MagicMock( - login=AsyncMock(), - client=Mock( - cloud_pipeline="12345", - create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock, - ), - ) + """Test logging in when an assist pipeline is available.""" + hass.data["cloud"] = MagicMock(login=AsyncMock()) - req = await cloud_client.post( - "/api/cloud/login", json={"email": "my_username", "password": "my_password"} - ) + with patch( + "homeassistant.components.cloud.http_api.assist_pipeline.async_get_pipelines", + return_value=[ + Mock( + conversation_engine="homeassistant", + id="12345", + stt_engine=DOMAIN, + tts_engine=DOMAIN, + ) + ], + ), patch( + "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + ) as create_pipeline_mock: + req = await cloud_client.post( + "/api/cloud/login", json={"email": "my_username", "password": "my_password"} + ) assert req.status == HTTPStatus.OK result = await req.json() assert result == {"success": True, "cloud_pipeline": "12345"} - create_cloud_assist_pipeline_mock.assert_not_awaited() + create_pipeline_mock.assert_not_awaited() async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None: """Test logging in when no assist pipeline is available.""" - create_cloud_assist_pipeline_mock = AsyncMock() - hass.data["cloud"] = MagicMock( - login=AsyncMock(), - client=Mock( - cloud_pipeline=None, - create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock, - ), - ) + hass.data["cloud"] = MagicMock(login=AsyncMock()) - req = await cloud_client.post( - "/api/cloud/login", json={"email": "my_username", "password": "my_password"} - ) + with patch( + "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + return_value=AsyncMock(id="12345"), + ) as create_pipeline_mock: + req = await cloud_client.post( + "/api/cloud/login", json={"email": "my_username", "password": "my_password"} + ) + + assert req.status == HTTPStatus.OK + result = await req.json() + assert result == {"success": True, "cloud_pipeline": "12345"} + create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud") + + +async def test_login_view_create_pipeline_fail( + hass: HomeAssistant, cloud_client +) -> None: + """Test logging in when no assist pipeline is available.""" + hass.data["cloud"] = MagicMock(login=AsyncMock()) + + with patch( + "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + return_value=None, + ) as create_pipeline_mock: + req = await cloud_client.post( + "/api/cloud/login", json={"email": "my_username", "password": "my_password"} + ) assert req.status == HTTPStatus.OK result = await req.json() assert result == {"success": True, "cloud_pipeline": None} - create_cloud_assist_pipeline_mock.assert_awaited_once() + create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud") async def test_login_view_random_exception(cloud_client) -> None: diff --git a/tests/components/cloud/test_init.py b/tests/components/cloud/test_init.py index 9f6631f0cb9..bd0a4972241 100644 --- a/tests/components/cloud/test_init.py +++ b/tests/components/cloud/test_init.py @@ -155,17 +155,24 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None: assert len(hass.states.async_entity_ids("binary_sensor")) == 0 - await cl.client.cloud_started() + # The on_start callback discovers the binary sensor platform + assert "async_setup" in str(cl._on_start[-1]) + await cl._on_start[-1]() await hass.async_block_till_done() assert len(hass.states.async_entity_ids("binary_sensor")) == 1 with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load: - await cl.iot._on_connect[-1]() + await cl._on_start[-1]() await hass.async_block_till_done() assert len(mock_load.mock_calls) == 0 + assert len(cloud_states) == 1 + assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_CONNECTED + + await cl.iot._on_connect[-1]() + await hass.async_block_till_done() assert len(cloud_states) == 2 assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_CONNECTED @@ -177,6 +184,11 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None: assert len(cloud_states) == 3 assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_DISCONNECTED + await cl.iot._on_disconnect[-1]() + await hass.async_block_till_done() + assert len(cloud_states) == 4 + assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_DISCONNECTED + async def test_remote_ui_url(hass: HomeAssistant, mock_cloud_fixture) -> None: """Test getting remote ui url.""" diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index 6f31d78ac28..ba88ae2af2d 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -48,8 +48,9 @@ async def test_prefs_default_voice( """Test cloud provider uses the preferences.""" assert cloud_prefs.tts_default_voice == ("en-US", "female") + tts_info = {"platform_loaded": Mock()} provider_pref = await tts.async_get_engine( - Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} + Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info ) provider_conf = await tts.async_get_engine( Mock(data={const.DOMAIN: cloud_with_prefs}), @@ -73,8 +74,9 @@ async def test_prefs_default_voice( async def test_provider_properties(cloud_with_prefs) -> None: """Test cloud provider.""" + tts_info = {"platform_loaded": Mock()} provider = await tts.async_get_engine( - Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} + Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info ) assert provider.supported_options == ["gender", "voice", "audio_output"] assert "nl-NL" in provider.supported_languages @@ -85,8 +87,9 @@ async def test_provider_properties(cloud_with_prefs) -> None: async def test_get_tts_audio(cloud_with_prefs) -> None: """Test cloud provider.""" + tts_info = {"platform_loaded": Mock()} provider = await tts.async_get_engine( - Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} + Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info ) assert provider.supported_options == ["gender", "voice", "audio_output"] assert "nl-NL" in provider.supported_languages