Automaticially create an assist pipeline using cloud stt + tts (#91991)

* Automaticially create an assist pipeline using cloud stt + tts

* Return the id of the cloud enabled pipeline

* Wait for platforms to load

* Fix typing

* Fix startup race

* Update tests

* Create a cloud pipeline only when logging in

* Fix tests

* Tweak _async_resolve_default_pipeline_settings

* Improve assist_pipeline test coverage

* Improve cloud test coverage
This commit is contained in:
Erik Montnemery 2023-04-26 03:40:01 +02:00 committed by GitHub
parent 74e0341d83
commit 57a59d808b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 303 additions and 68 deletions

View file

@ -17,13 +17,17 @@ from .pipeline import (
PipelineInput, PipelineInput,
PipelineRun, PipelineRun,
PipelineStage, PipelineStage,
async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines,
async_setup_pipeline_store, async_setup_pipeline_store,
) )
from .websocket_api import async_register_websocket_api from .websocket_api import async_register_websocket_api
__all__ = ( __all__ = (
"DOMAIN", "DOMAIN",
"async_create_default_pipeline",
"async_get_pipelines",
"async_setup", "async_setup",
"async_pipeline_from_audio_stream", "async_pipeline_from_audio_stream",
"Pipeline", "Pipeline",

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncIterable, Callable from collections.abc import AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
import logging import logging
from typing import Any from typing import Any
@ -75,20 +75,22 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10 SAVE_DELAY = 10
async def _async_create_default_pipeline( async def _async_resolve_default_pipeline_settings(
hass: HomeAssistant, pipeline_store: PipelineStorageCollection hass: HomeAssistant,
) -> Pipeline: stt_engine_id: str | None,
"""Create a default pipeline. tts_engine_id: str | None,
) -> dict[str, str | None]:
"""Resolve settings for a default pipeline.
The default pipeline will use the homeassistant conversation agent and the The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines. default stt / tts engines if none are specified.
""" """
conversation_language = "en" conversation_language = "en"
pipeline_language = "en" pipeline_language = "en"
pipeline_name = "Home Assistant" pipeline_name = "Home Assistant"
stt_engine_id = None stt_engine = None
stt_language = None stt_language = None
tts_engine_id = None tts_engine = None
tts_language = None tts_language = None
tts_voice = None tts_voice = None
@ -104,12 +106,15 @@ async def _async_create_default_pipeline(
pipeline_language = hass.config.language pipeline_language = hass.config.language
conversation_language = conversation_languages[0] conversation_language = conversation_languages[0]
if (stt_engine_id := stt.async_default_engine(hass)) is not None and ( if stt_engine_id is None:
stt_engine := stt.async_get_speech_to_text_engine( stt_engine_id = stt.async_default_engine(hass)
hass,
stt_engine_id, if stt_engine_id is not None:
) stt_engine = stt.async_get_speech_to_text_engine(hass, stt_engine_id)
): if stt_engine is None:
stt_engine_id = None
if stt_engine:
stt_languages = language_util.matches( stt_languages = language_util.matches(
pipeline_language, pipeline_language,
stt_engine.supported_languages, stt_engine.supported_languages,
@ -125,12 +130,15 @@ async def _async_create_default_pipeline(
) )
stt_engine_id = None stt_engine_id = None
if (tts_engine_id := tts.async_default_engine(hass)) is not None and ( if tts_engine_id is None:
tts_engine := tts.get_engine_instance( tts_engine_id = tts.async_default_engine(hass)
hass,
tts_engine_id, if tts_engine_id is not None:
) tts_engine = tts.get_engine_instance(hass, tts_engine_id)
): if tts_engine is None:
tts_engine_id = None
if tts_engine:
tts_languages = language_util.matches( tts_languages = language_util.matches(
pipeline_language, pipeline_language,
tts_engine.supported_languages, tts_engine.supported_languages,
@ -152,19 +160,50 @@ async def _async_create_default_pipeline(
if stt_engine_id == "cloud" and tts_engine_id == "cloud": if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud" pipeline_name = "Home Assistant Cloud"
return await pipeline_store.async_create_item( return {
{ "conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_engine": conversation.HOME_ASSISTANT_AGENT, "conversation_language": conversation_language,
"conversation_language": conversation_language, "language": hass.config.language,
"language": hass.config.language, "name": pipeline_name,
"name": pipeline_name, "stt_engine": stt_engine_id,
"stt_engine": stt_engine_id, "stt_language": stt_language,
"stt_language": stt_language, "tts_engine": tts_engine_id,
"tts_engine": tts_engine_id, "tts_language": tts_language,
"tts_language": tts_language, "tts_voice": tts_voice,
"tts_voice": tts_voice, }
}
async def _async_create_default_pipeline(
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
) -> Pipeline:
"""Create a default pipeline.
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines.
"""
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None)
return await pipeline_store.async_create_item(pipeline_settings)
async def async_create_default_pipeline(
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str
) -> Pipeline | None:
"""Create a pipeline with default settings.
The default pipeline will use the homeassistant conversation agent and the
specified stt / tts engines.
"""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id
) )
if (
pipeline_settings["stt_engine"] != stt_engine_id
or pipeline_settings["tts_engine"] != tts_engine_id
):
return None
return await pipeline_store.async_create_item(pipeline_settings)
@callback @callback
@ -181,6 +220,14 @@ def async_get_pipeline(
return pipeline_data.pipeline_store.data.get(pipeline_id) return pipeline_data.pipeline_store.data.get(pipeline_id)
@callback
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]
return pipeline_data.pipeline_store.data.values()
class PipelineEventType(StrEnum): class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run.""" """Event types emitted during a pipeline run."""

View file

@ -238,9 +238,27 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await prefs.async_initialize() await prefs.async_initialize()
# Initialize Cloud # Initialize Cloud
loaded = False
async def _discover_platforms():
"""Discover platforms."""
nonlocal loaded
# Prevent multiple discovery
if loaded:
return
loaded = True
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)
websession = async_get_clientsession(hass) websession = async_get_clientsession(hass)
client = CloudClient(hass, prefs, websession, alexa_conf, google_conf) client = CloudClient(
hass, prefs, websession, alexa_conf, google_conf, _discover_platforms
)
cloud = hass.data[DOMAIN] = Cloud(client, **kwargs) cloud = hass.data[DOMAIN] = Cloud(client, **kwargs)
cloud.iot.register_on_connect(client.on_cloud_connected)
async def _shutdown(event): async def _shutdown(event):
"""Shutdown event.""" """Shutdown event."""
@ -262,8 +280,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
) )
loaded = False
async def async_startup_repairs(_=None) -> None: async def async_startup_repairs(_=None) -> None:
"""Create repair issues after startup.""" """Create repair issues after startup."""
if not cloud.is_logged_in: if not cloud.is_logged_in:
@ -272,23 +288,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if subscription_info := await async_subscription_info(cloud): if subscription_info := await async_subscription_info(cloud):
async_manage_legacy_subscription_issue(hass, subscription_info) async_manage_legacy_subscription_issue(hass, subscription_info)
async def _discover_platforms():
"""Discover platforms."""
nonlocal loaded
# Prevent multiple discovery
if loaded:
return
loaded = True
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)
async def _on_connect(): async def _on_connect():
"""Handle cloud connect.""" """Handle cloud connect."""
await _discover_platforms()
async_dispatcher_send( async_dispatcher_send(
hass, SIGNAL_CLOUD_CONNECTION_STATE, CloudConnectionState.CLOUD_CONNECTED hass, SIGNAL_CLOUD_CONNECTION_STATE, CloudConnectionState.CLOUD_CONNECTED
) )

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine
from http import HTTPStatus from http import HTTPStatus
import logging import logging
from pathlib import Path from pathlib import Path
@ -10,7 +11,13 @@ from typing import Any
import aiohttp import aiohttp
from hass_nabucasa.client import CloudClient as Interface from hass_nabucasa.client import CloudClient as Interface
from homeassistant.components import google_assistant, persistent_notification, webhook from homeassistant.components import (
assist_pipeline,
conversation,
google_assistant,
persistent_notification,
webhook,
)
from homeassistant.components.alexa import ( from homeassistant.components.alexa import (
errors as alexa_errors, errors as alexa_errors,
smart_home as alexa_smart_home, smart_home as alexa_smart_home,
@ -36,6 +43,7 @@ class CloudClient(Interface):
websession: aiohttp.ClientSession, websession: aiohttp.ClientSession,
alexa_user_config: dict[str, Any], alexa_user_config: dict[str, Any],
google_user_config: dict[str, Any], google_user_config: dict[str, Any],
on_started_cb: Callable[[], Coroutine[Any, Any, None]],
) -> None: ) -> None:
"""Initialize client interface to Cloud.""" """Initialize client interface to Cloud."""
self._hass = hass self._hass = hass
@ -48,6 +56,10 @@ class CloudClient(Interface):
self._alexa_config_init_lock = asyncio.Lock() self._alexa_config_init_lock = asyncio.Lock()
self._google_config_init_lock = asyncio.Lock() self._google_config_init_lock = asyncio.Lock()
self._relayer_region: str | None = None self._relayer_region: str | None = None
self._on_started_cb = on_started_cb
self.cloud_pipeline = self._cloud_assist_pipeline()
self.stt_platform_loaded = asyncio.Event()
self.tts_platform_loaded = asyncio.Event()
@property @property
def base_path(self) -> Path: def base_path(self) -> Path:
@ -136,8 +148,24 @@ class CloudClient(Interface):
return self._google_config return self._google_config
async def cloud_started(self) -> None: def _cloud_assist_pipeline(self) -> str | None:
"""When cloud is started.""" """Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(self._hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None
async def create_cloud_assist_pipeline(self) -> None:
"""Create a cloud-enabled assist pipeline."""
await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN)
self.cloud_pipeline = self._cloud_assist_pipeline()
async def on_cloud_connected(self) -> None:
"""When cloud is connected."""
is_new_user = await self.prefs.async_set_username(self.cloud.username) is_new_user = await self.prefs.async_set_username(self.cloud.username)
async def enable_alexa(_): async def enable_alexa(_):
@ -181,6 +209,14 @@ class CloudClient(Interface):
if tasks: if tasks:
await asyncio.gather(*(task(None) for task in tasks)) await asyncio.gather(*(task(None) for task in tasks))
async def cloud_started(self) -> None:
"""When cloud is started."""
await self._on_started_cb()
await asyncio.gather(
self.stt_platform_loaded.wait(),
self.tts_platform_loaded.wait(),
)
async def cloud_stopped(self) -> None: async def cloud_stopped(self) -> None:
"""When the cloud is stopped.""" """When the cloud is stopped."""

View file

@ -186,7 +186,11 @@ class CloudLoginView(HomeAssistantView):
cloud = hass.data[DOMAIN] cloud = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"]) await cloud.login(data["email"], data["password"])
return self.json({"success": True}) if cloud.client.cloud_pipeline is None:
await cloud.client.create_cloud_assist_pipeline()
return self.json(
{"success": True, "cloud_pipeline": cloud.client.cloud_pipeline}
)
class CloudLogoutView(HomeAssistantView): class CloudLogoutView(HomeAssistantView):

View file

@ -3,7 +3,7 @@
"name": "Home Assistant Cloud", "name": "Home Assistant Cloud",
"after_dependencies": ["google_assistant", "alexa"], "after_dependencies": ["google_assistant", "alexa"],
"codeowners": ["@home-assistant/cloud"], "codeowners": ["@home-assistant/cloud"],
"dependencies": ["homeassistant", "http", "webhook"], "dependencies": ["assist_pipeline", "homeassistant", "http", "webhook"],
"documentation": "https://www.home-assistant.io/integrations/cloud", "documentation": "https://www.home-assistant.io/integrations/cloud",
"integration_type": "system", "integration_type": "system",
"iot_class": "cloud_push", "iot_class": "cloud_push",

View file

@ -28,7 +28,9 @@ async def async_get_engine(hass, config, discovery_info=None):
"""Set up Cloud speech component.""" """Set up Cloud speech component."""
cloud: Cloud = hass.data[DOMAIN] cloud: Cloud = hass.data[DOMAIN]
return CloudProvider(cloud) cloud_provider = CloudProvider(cloud)
cloud.client.stt_platform_loaded.set()
return cloud_provider
class CloudProvider(Provider): class CloudProvider(Provider):

View file

@ -63,7 +63,9 @@ async def async_get_engine(hass, config, discovery_info=None):
language = config[CONF_LANG] language = config[CONF_LANG]
gender = config[ATTR_GENDER] gender = config[ATTR_GENDER]
return CloudProvider(cloud, language, gender) cloud_provider = CloudProvider(cloud, language, gender)
cloud.client.tts_platform_loaded.set()
return cloud_provider
class CloudProvider(Provider): class CloudProvider(Provider):

View file

@ -1,6 +1,6 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import ANY, AsyncMock, patch
import pytest import pytest
@ -11,7 +11,9 @@ from homeassistant.components.assist_pipeline.pipeline import (
Pipeline, Pipeline,
PipelineData, PipelineData,
PipelineStorageCollection, PipelineStorageCollection,
async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
@ -143,6 +145,31 @@ async def test_loading_datasets_from_storage(
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
async def test_create_default_pipeline(
hass: HomeAssistant, init_supporting_components
) -> None:
"""Test async_create_default_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
assert await async_create_default_pipeline(hass, "bla", "bla") is None
assert await async_create_default_pipeline(hass, "test", "test") == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=ANY,
language="en",
name="Home Assistant",
stt_engine="test",
stt_language="en-US",
tts_engine="test",
tts_language="en-US",
tts_voice="james_earl_jones",
)
async def test_get_pipeline(hass: HomeAssistant) -> None: async def test_get_pipeline(hass: HomeAssistant) -> None:
"""Test async_get_pipeline.""" """Test async_get_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {}) assert await async_setup_component(hass, "assist_pipeline", {})
@ -159,6 +186,31 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
assert pipeline is async_get_pipeline(hass, pipeline.id) assert pipeline is async_get_pipeline(hass, pipeline.id)
async def test_get_pipelines(hass: HomeAssistant) -> None:
"""Test async_get_pipelines."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
pipelines = async_get_pipelines(hass)
assert list(pipelines) == [
Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=ANY,
language="en",
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
)
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
("ha_language", "ha_country", "conv_language", "pipeline_language"), ("ha_language", "ha_country", "conv_language", "pipeline_language"),
[ [

View file

@ -8,6 +8,13 @@ from homeassistant.components.cloud import const, prefs
from . import mock_cloud, mock_cloud_prefs from . import mock_cloud, mock_cloud_prefs
# Prevent TTS cache from being created
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
init_cache_dir_side_effect,
mock_get_cache_files,
mock_init_cache_dir,
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_user_data(): def mock_user_data():

View file

@ -6,6 +6,11 @@ import aiohttp
from aiohttp import web from aiohttp import web
import pytest import pytest
from homeassistant.components.assist_pipeline import (
Pipeline,
async_get_pipeline,
async_get_pipelines,
)
from homeassistant.components.cloud import DOMAIN from homeassistant.components.cloud import DOMAIN
from homeassistant.components.cloud.client import CloudClient from homeassistant.components.cloud.client import CloudClient
from homeassistant.components.cloud.const import ( from homeassistant.components.cloud.const import (
@ -298,23 +303,31 @@ async def test_google_config_should_2fa(
assert not gconf.should_2fa(state) assert not gconf.should_2fa(state)
async def test_set_username(hass: HomeAssistant) -> None: @patch(
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None:
"""Test we set username during login.""" """Test we set username during login."""
prefs = MagicMock( prefs = MagicMock(
alexa_enabled=False, alexa_enabled=False,
google_enabled=False, google_enabled=False,
async_set_username=AsyncMock(return_value=None), async_set_username=AsyncMock(return_value=None),
) )
client = CloudClient(hass, prefs, None, {}, {}) client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
client.cloud = MagicMock(is_logged_in=True, username="mock-username") client.cloud = MagicMock(is_logged_in=True, username="mock-username")
await client.cloud_started() await client.on_cloud_connected()
assert len(prefs.async_set_username.mock_calls) == 1 assert len(prefs.async_set_username.mock_calls) == 1
assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username" assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username"
@patch(
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_login_recovers_bad_internet( async def test_login_recovers_bad_internet(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test Alexa can recover bad auth.""" """Test Alexa can recover bad auth."""
prefs = Mock( prefs = Mock(
@ -322,12 +335,12 @@ async def test_login_recovers_bad_internet(
google_enabled=False, google_enabled=False,
async_set_username=AsyncMock(return_value=None), async_set_username=AsyncMock(return_value=None),
) )
client = CloudClient(hass, prefs, None, {}, {}) client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
client.cloud = Mock() client.cloud = Mock()
client._alexa_config = Mock( client._alexa_config = Mock(
async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError) async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError)
) )
await client.cloud_started() await client.on_cloud_connected()
assert len(client._alexa_config.async_enable_proactive_mode.mock_calls) == 1 assert len(client._alexa_config.async_enable_proactive_mode.mock_calls) == 1
assert "Unable to activate Alexa Report State" in caplog.text assert "Unable to activate Alexa Report State" in caplog.text
@ -354,3 +367,29 @@ async def test_system_msg(hass: HomeAssistant) -> None:
assert response is None assert response is None
assert cloud.client.relayer_region == "xx-earth-616" assert cloud.client.relayer_region == "xx-earth-616"
async def test_create_cloud_assist_pipeline(
hass: HomeAssistant, mock_cloud_setup, mock_cloud_login
) -> None:
"""Test creating a cloud enabled assist pipeline."""
cloud_client: CloudClient = hass.data[DOMAIN].client
await cloud_client.cloud_started()
assert cloud_client.cloud_pipeline is None
assert len(async_get_pipelines(hass)) == 1
await cloud_client.create_cloud_assist_pipeline()
assert cloud_client.cloud_pipeline is not None
assert len(async_get_pipelines(hass)) == 2
assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=cloud_client.cloud_pipeline,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="JennyNeural",
)

View file

@ -105,7 +105,14 @@ async def test_google_actions_sync_fails(
async def test_login_view(hass: HomeAssistant, cloud_client) -> None: async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in.""" """Test logging in."""
hass.data["cloud"] = MagicMock(login=AsyncMock()) create_cloud_assist_pipeline_mock = AsyncMock()
hass.data["cloud"] = MagicMock(
login=AsyncMock(),
client=Mock(
cloud_pipeline="12345",
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
),
)
req = await cloud_client.post( req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"} "/api/cloud/login", json={"email": "my_username", "password": "my_password"}
@ -113,7 +120,29 @@ async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
result = await req.json() result = await req.json()
assert result == {"success": True} assert result == {"success": True, "cloud_pipeline": "12345"}
create_cloud_assist_pipeline_mock.assert_not_awaited()
async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in when no assist pipeline is available."""
create_cloud_assist_pipeline_mock = AsyncMock()
hass.data["cloud"] = MagicMock(
login=AsyncMock(),
client=Mock(
cloud_pipeline=None,
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
),
)
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
create_cloud_assist_pipeline_mock.assert_awaited_once()
async def test_login_view_random_exception(cloud_client) -> None: async def test_login_view_random_exception(cloud_client) -> None:

View file

@ -2,6 +2,7 @@
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
from hass_nabucasa import Cloud
import pytest import pytest
from homeassistant.components import cloud from homeassistant.components import cloud
@ -134,9 +135,9 @@ async def test_setup_existing_cloud_user(
async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None: async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
"""Test cloud on connect triggers.""" """Test cloud on connect triggers."""
cl = hass.data["cloud"] cl: Cloud = hass.data["cloud"]
assert len(cl.iot._on_connect) == 3 assert len(cl.iot._on_connect) == 4
assert len(hass.states.async_entity_ids("binary_sensor")) == 0 assert len(hass.states.async_entity_ids("binary_sensor")) == 0
@ -152,6 +153,11 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
await cl.iot._on_connect[-1]() await cl.iot._on_connect[-1]()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass.states.async_entity_ids("binary_sensor")) == 0
await cl.client.cloud_started()
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids("binary_sensor")) == 1 assert len(hass.states.async_entity_ids("binary_sensor")) == 1
with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load: with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load:

View file

@ -511,6 +511,12 @@ async def test_get_engine_legacy(
TEST_DOMAIN, TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider), async_get_engine=AsyncMock(return_value=mock_provider),
) )
mock_stt_platform(
hass,
tmp_path,
"cloud",
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component( assert await async_setup_component(
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]} hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
) )