Automaticially create an assist pipeline using cloud stt + tts (#91991)
* Automaticially create an assist pipeline using cloud stt + tts * Return the id of the cloud enabled pipeline * Wait for platforms to load * Fix typing * Fix startup race * Update tests * Create a cloud pipeline only when logging in * Fix tests * Tweak _async_resolve_default_pipeline_settings * Improve assist_pipeline test coverage * Improve cloud test coverage
This commit is contained in:
parent
74e0341d83
commit
57a59d808b
14 changed files with 303 additions and 68 deletions
|
@ -17,13 +17,17 @@ from .pipeline import (
|
||||||
PipelineInput,
|
PipelineInput,
|
||||||
PipelineRun,
|
PipelineRun,
|
||||||
PipelineStage,
|
PipelineStage,
|
||||||
|
async_create_default_pipeline,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
async_setup_pipeline_store,
|
async_setup_pipeline_store,
|
||||||
)
|
)
|
||||||
from .websocket_api import async_register_websocket_api
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
"async_create_default_pipeline",
|
||||||
|
"async_get_pipelines",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_pipeline_from_audio_stream",
|
"async_pipeline_from_audio_stream",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable, Callable
|
from collections.abc import AsyncIterable, Callable, Iterable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -75,20 +75,22 @@ STORED_PIPELINE_RUNS = 10
|
||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
async def _async_create_default_pipeline(
|
async def _async_resolve_default_pipeline_settings(
|
||||||
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
|
hass: HomeAssistant,
|
||||||
) -> Pipeline:
|
stt_engine_id: str | None,
|
||||||
"""Create a default pipeline.
|
tts_engine_id: str | None,
|
||||||
|
) -> dict[str, str | None]:
|
||||||
|
"""Resolve settings for a default pipeline.
|
||||||
|
|
||||||
The default pipeline will use the homeassistant conversation agent and the
|
The default pipeline will use the homeassistant conversation agent and the
|
||||||
default stt / tts engines.
|
default stt / tts engines if none are specified.
|
||||||
"""
|
"""
|
||||||
conversation_language = "en"
|
conversation_language = "en"
|
||||||
pipeline_language = "en"
|
pipeline_language = "en"
|
||||||
pipeline_name = "Home Assistant"
|
pipeline_name = "Home Assistant"
|
||||||
stt_engine_id = None
|
stt_engine = None
|
||||||
stt_language = None
|
stt_language = None
|
||||||
tts_engine_id = None
|
tts_engine = None
|
||||||
tts_language = None
|
tts_language = None
|
||||||
tts_voice = None
|
tts_voice = None
|
||||||
|
|
||||||
|
@ -104,12 +106,15 @@ async def _async_create_default_pipeline(
|
||||||
pipeline_language = hass.config.language
|
pipeline_language = hass.config.language
|
||||||
conversation_language = conversation_languages[0]
|
conversation_language = conversation_languages[0]
|
||||||
|
|
||||||
if (stt_engine_id := stt.async_default_engine(hass)) is not None and (
|
if stt_engine_id is None:
|
||||||
stt_engine := stt.async_get_speech_to_text_engine(
|
stt_engine_id = stt.async_default_engine(hass)
|
||||||
hass,
|
|
||||||
stt_engine_id,
|
if stt_engine_id is not None:
|
||||||
)
|
stt_engine = stt.async_get_speech_to_text_engine(hass, stt_engine_id)
|
||||||
):
|
if stt_engine is None:
|
||||||
|
stt_engine_id = None
|
||||||
|
|
||||||
|
if stt_engine:
|
||||||
stt_languages = language_util.matches(
|
stt_languages = language_util.matches(
|
||||||
pipeline_language,
|
pipeline_language,
|
||||||
stt_engine.supported_languages,
|
stt_engine.supported_languages,
|
||||||
|
@ -125,12 +130,15 @@ async def _async_create_default_pipeline(
|
||||||
)
|
)
|
||||||
stt_engine_id = None
|
stt_engine_id = None
|
||||||
|
|
||||||
if (tts_engine_id := tts.async_default_engine(hass)) is not None and (
|
if tts_engine_id is None:
|
||||||
tts_engine := tts.get_engine_instance(
|
tts_engine_id = tts.async_default_engine(hass)
|
||||||
hass,
|
|
||||||
tts_engine_id,
|
if tts_engine_id is not None:
|
||||||
)
|
tts_engine = tts.get_engine_instance(hass, tts_engine_id)
|
||||||
):
|
if tts_engine is None:
|
||||||
|
tts_engine_id = None
|
||||||
|
|
||||||
|
if tts_engine:
|
||||||
tts_languages = language_util.matches(
|
tts_languages = language_util.matches(
|
||||||
pipeline_language,
|
pipeline_language,
|
||||||
tts_engine.supported_languages,
|
tts_engine.supported_languages,
|
||||||
|
@ -152,19 +160,50 @@ async def _async_create_default_pipeline(
|
||||||
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
|
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
|
||||||
pipeline_name = "Home Assistant Cloud"
|
pipeline_name = "Home Assistant Cloud"
|
||||||
|
|
||||||
return await pipeline_store.async_create_item(
|
return {
|
||||||
{
|
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
||||||
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
"conversation_language": conversation_language,
|
||||||
"conversation_language": conversation_language,
|
"language": hass.config.language,
|
||||||
"language": hass.config.language,
|
"name": pipeline_name,
|
||||||
"name": pipeline_name,
|
"stt_engine": stt_engine_id,
|
||||||
"stt_engine": stt_engine_id,
|
"stt_language": stt_language,
|
||||||
"stt_language": stt_language,
|
"tts_engine": tts_engine_id,
|
||||||
"tts_engine": tts_engine_id,
|
"tts_language": tts_language,
|
||||||
"tts_language": tts_language,
|
"tts_voice": tts_voice,
|
||||||
"tts_voice": tts_voice,
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
async def _async_create_default_pipeline(
|
||||||
|
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
|
||||||
|
) -> Pipeline:
|
||||||
|
"""Create a default pipeline.
|
||||||
|
|
||||||
|
The default pipeline will use the homeassistant conversation agent and the
|
||||||
|
default stt / tts engines.
|
||||||
|
"""
|
||||||
|
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None)
|
||||||
|
return await pipeline_store.async_create_item(pipeline_settings)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_create_default_pipeline(
|
||||||
|
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str
|
||||||
|
) -> Pipeline | None:
|
||||||
|
"""Create a pipeline with default settings.
|
||||||
|
|
||||||
|
The default pipeline will use the homeassistant conversation agent and the
|
||||||
|
specified stt / tts engines.
|
||||||
|
"""
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_settings = await _async_resolve_default_pipeline_settings(
|
||||||
|
hass, stt_engine_id, tts_engine_id
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
pipeline_settings["stt_engine"] != stt_engine_id
|
||||||
|
or pipeline_settings["tts_engine"] != tts_engine_id
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
return await pipeline_store.async_create_item(pipeline_settings)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -181,6 +220,14 @@ def async_get_pipeline(
|
||||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
|
||||||
|
"""Get all pipelines."""
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
return pipeline_data.pipeline_store.data.values()
|
||||||
|
|
||||||
|
|
||||||
class PipelineEventType(StrEnum):
|
class PipelineEventType(StrEnum):
|
||||||
"""Event types emitted during a pipeline run."""
|
"""Event types emitted during a pipeline run."""
|
||||||
|
|
||||||
|
|
|
@ -238,9 +238,27 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
await prefs.async_initialize()
|
await prefs.async_initialize()
|
||||||
|
|
||||||
# Initialize Cloud
|
# Initialize Cloud
|
||||||
|
loaded = False
|
||||||
|
|
||||||
|
async def _discover_platforms():
|
||||||
|
"""Discover platforms."""
|
||||||
|
nonlocal loaded
|
||||||
|
|
||||||
|
# Prevent multiple discovery
|
||||||
|
if loaded:
|
||||||
|
return
|
||||||
|
loaded = True
|
||||||
|
|
||||||
|
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
|
||||||
|
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
|
||||||
|
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)
|
||||||
|
|
||||||
websession = async_get_clientsession(hass)
|
websession = async_get_clientsession(hass)
|
||||||
client = CloudClient(hass, prefs, websession, alexa_conf, google_conf)
|
client = CloudClient(
|
||||||
|
hass, prefs, websession, alexa_conf, google_conf, _discover_platforms
|
||||||
|
)
|
||||||
cloud = hass.data[DOMAIN] = Cloud(client, **kwargs)
|
cloud = hass.data[DOMAIN] = Cloud(client, **kwargs)
|
||||||
|
cloud.iot.register_on_connect(client.on_cloud_connected)
|
||||||
|
|
||||||
async def _shutdown(event):
|
async def _shutdown(event):
|
||||||
"""Shutdown event."""
|
"""Shutdown event."""
|
||||||
|
@ -262,8 +280,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
|
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
loaded = False
|
|
||||||
|
|
||||||
async def async_startup_repairs(_=None) -> None:
|
async def async_startup_repairs(_=None) -> None:
|
||||||
"""Create repair issues after startup."""
|
"""Create repair issues after startup."""
|
||||||
if not cloud.is_logged_in:
|
if not cloud.is_logged_in:
|
||||||
|
@ -272,23 +288,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
if subscription_info := await async_subscription_info(cloud):
|
if subscription_info := await async_subscription_info(cloud):
|
||||||
async_manage_legacy_subscription_issue(hass, subscription_info)
|
async_manage_legacy_subscription_issue(hass, subscription_info)
|
||||||
|
|
||||||
async def _discover_platforms():
|
|
||||||
"""Discover platforms."""
|
|
||||||
nonlocal loaded
|
|
||||||
|
|
||||||
# Prevent multiple discovery
|
|
||||||
if loaded:
|
|
||||||
return
|
|
||||||
loaded = True
|
|
||||||
|
|
||||||
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
|
|
||||||
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
|
|
||||||
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)
|
|
||||||
|
|
||||||
async def _on_connect():
|
async def _on_connect():
|
||||||
"""Handle cloud connect."""
|
"""Handle cloud connect."""
|
||||||
await _discover_platforms()
|
|
||||||
|
|
||||||
async_dispatcher_send(
|
async_dispatcher_send(
|
||||||
hass, SIGNAL_CLOUD_CONNECTION_STATE, CloudConnectionState.CLOUD_CONNECTED
|
hass, SIGNAL_CLOUD_CONNECTION_STATE, CloudConnectionState.CLOUD_CONNECTED
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -10,7 +11,13 @@ from typing import Any
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from hass_nabucasa.client import CloudClient as Interface
|
from hass_nabucasa.client import CloudClient as Interface
|
||||||
|
|
||||||
from homeassistant.components import google_assistant, persistent_notification, webhook
|
from homeassistant.components import (
|
||||||
|
assist_pipeline,
|
||||||
|
conversation,
|
||||||
|
google_assistant,
|
||||||
|
persistent_notification,
|
||||||
|
webhook,
|
||||||
|
)
|
||||||
from homeassistant.components.alexa import (
|
from homeassistant.components.alexa import (
|
||||||
errors as alexa_errors,
|
errors as alexa_errors,
|
||||||
smart_home as alexa_smart_home,
|
smart_home as alexa_smart_home,
|
||||||
|
@ -36,6 +43,7 @@ class CloudClient(Interface):
|
||||||
websession: aiohttp.ClientSession,
|
websession: aiohttp.ClientSession,
|
||||||
alexa_user_config: dict[str, Any],
|
alexa_user_config: dict[str, Any],
|
||||||
google_user_config: dict[str, Any],
|
google_user_config: dict[str, Any],
|
||||||
|
on_started_cb: Callable[[], Coroutine[Any, Any, None]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize client interface to Cloud."""
|
"""Initialize client interface to Cloud."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
@ -48,6 +56,10 @@ class CloudClient(Interface):
|
||||||
self._alexa_config_init_lock = asyncio.Lock()
|
self._alexa_config_init_lock = asyncio.Lock()
|
||||||
self._google_config_init_lock = asyncio.Lock()
|
self._google_config_init_lock = asyncio.Lock()
|
||||||
self._relayer_region: str | None = None
|
self._relayer_region: str | None = None
|
||||||
|
self._on_started_cb = on_started_cb
|
||||||
|
self.cloud_pipeline = self._cloud_assist_pipeline()
|
||||||
|
self.stt_platform_loaded = asyncio.Event()
|
||||||
|
self.tts_platform_loaded = asyncio.Event()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_path(self) -> Path:
|
def base_path(self) -> Path:
|
||||||
|
@ -136,8 +148,24 @@ class CloudClient(Interface):
|
||||||
|
|
||||||
return self._google_config
|
return self._google_config
|
||||||
|
|
||||||
async def cloud_started(self) -> None:
|
def _cloud_assist_pipeline(self) -> str | None:
|
||||||
"""When cloud is started."""
|
"""Return the ID of a cloud-enabled assist pipeline or None."""
|
||||||
|
for pipeline in assist_pipeline.async_get_pipelines(self._hass):
|
||||||
|
if (
|
||||||
|
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
|
||||||
|
and pipeline.stt_engine == DOMAIN
|
||||||
|
and pipeline.tts_engine == DOMAIN
|
||||||
|
):
|
||||||
|
return pipeline.id
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_cloud_assist_pipeline(self) -> None:
|
||||||
|
"""Create a cloud-enabled assist pipeline."""
|
||||||
|
await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN)
|
||||||
|
self.cloud_pipeline = self._cloud_assist_pipeline()
|
||||||
|
|
||||||
|
async def on_cloud_connected(self) -> None:
|
||||||
|
"""When cloud is connected."""
|
||||||
is_new_user = await self.prefs.async_set_username(self.cloud.username)
|
is_new_user = await self.prefs.async_set_username(self.cloud.username)
|
||||||
|
|
||||||
async def enable_alexa(_):
|
async def enable_alexa(_):
|
||||||
|
@ -181,6 +209,14 @@ class CloudClient(Interface):
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*(task(None) for task in tasks))
|
await asyncio.gather(*(task(None) for task in tasks))
|
||||||
|
|
||||||
|
async def cloud_started(self) -> None:
|
||||||
|
"""When cloud is started."""
|
||||||
|
await self._on_started_cb()
|
||||||
|
await asyncio.gather(
|
||||||
|
self.stt_platform_loaded.wait(),
|
||||||
|
self.tts_platform_loaded.wait(),
|
||||||
|
)
|
||||||
|
|
||||||
async def cloud_stopped(self) -> None:
|
async def cloud_stopped(self) -> None:
|
||||||
"""When the cloud is stopped."""
|
"""When the cloud is stopped."""
|
||||||
|
|
||||||
|
|
|
@ -186,7 +186,11 @@ class CloudLoginView(HomeAssistantView):
|
||||||
cloud = hass.data[DOMAIN]
|
cloud = hass.data[DOMAIN]
|
||||||
await cloud.login(data["email"], data["password"])
|
await cloud.login(data["email"], data["password"])
|
||||||
|
|
||||||
return self.json({"success": True})
|
if cloud.client.cloud_pipeline is None:
|
||||||
|
await cloud.client.create_cloud_assist_pipeline()
|
||||||
|
return self.json(
|
||||||
|
{"success": True, "cloud_pipeline": cloud.client.cloud_pipeline}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CloudLogoutView(HomeAssistantView):
|
class CloudLogoutView(HomeAssistantView):
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
"name": "Home Assistant Cloud",
|
"name": "Home Assistant Cloud",
|
||||||
"after_dependencies": ["google_assistant", "alexa"],
|
"after_dependencies": ["google_assistant", "alexa"],
|
||||||
"codeowners": ["@home-assistant/cloud"],
|
"codeowners": ["@home-assistant/cloud"],
|
||||||
"dependencies": ["homeassistant", "http", "webhook"],
|
"dependencies": ["assist_pipeline", "homeassistant", "http", "webhook"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/cloud",
|
"documentation": "https://www.home-assistant.io/integrations/cloud",
|
||||||
"integration_type": "system",
|
"integration_type": "system",
|
||||||
"iot_class": "cloud_push",
|
"iot_class": "cloud_push",
|
||||||
|
|
|
@ -28,7 +28,9 @@ async def async_get_engine(hass, config, discovery_info=None):
|
||||||
"""Set up Cloud speech component."""
|
"""Set up Cloud speech component."""
|
||||||
cloud: Cloud = hass.data[DOMAIN]
|
cloud: Cloud = hass.data[DOMAIN]
|
||||||
|
|
||||||
return CloudProvider(cloud)
|
cloud_provider = CloudProvider(cloud)
|
||||||
|
cloud.client.stt_platform_loaded.set()
|
||||||
|
return cloud_provider
|
||||||
|
|
||||||
|
|
||||||
class CloudProvider(Provider):
|
class CloudProvider(Provider):
|
||||||
|
|
|
@ -63,7 +63,9 @@ async def async_get_engine(hass, config, discovery_info=None):
|
||||||
language = config[CONF_LANG]
|
language = config[CONF_LANG]
|
||||||
gender = config[ATTR_GENDER]
|
gender = config[ATTR_GENDER]
|
||||||
|
|
||||||
return CloudProvider(cloud, language, gender)
|
cloud_provider = CloudProvider(cloud, language, gender)
|
||||||
|
cloud.client.tts_platform_loaded.set()
|
||||||
|
return cloud_provider
|
||||||
|
|
||||||
|
|
||||||
class CloudProvider(Provider):
|
class CloudProvider(Provider):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import ANY, AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -11,7 +11,9 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineStorageCollection,
|
PipelineStorageCollection,
|
||||||
|
async_create_default_pipeline,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
|
@ -143,6 +145,31 @@ async def test_loading_datasets_from_storage(
|
||||||
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_default_pipeline(
|
||||||
|
hass: HomeAssistant, init_supporting_components
|
||||||
|
) -> None:
|
||||||
|
"""Test async_create_default_pipeline."""
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
assert await async_create_default_pipeline(hass, "bla", "bla") is None
|
||||||
|
assert await async_create_default_pipeline(hass, "test", "test") == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=ANY,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine="test",
|
||||||
|
stt_language="en-US",
|
||||||
|
tts_engine="test",
|
||||||
|
tts_language="en-US",
|
||||||
|
tts_voice="james_earl_jones",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_pipeline(hass: HomeAssistant) -> None:
|
async def test_get_pipeline(hass: HomeAssistant) -> None:
|
||||||
"""Test async_get_pipeline."""
|
"""Test async_get_pipeline."""
|
||||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
@ -159,6 +186,31 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
|
||||||
assert pipeline is async_get_pipeline(hass, pipeline.id)
|
assert pipeline is async_get_pipeline(hass, pipeline.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pipelines(hass: HomeAssistant) -> None:
|
||||||
|
"""Test async_get_pipelines."""
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
pipelines = async_get_pipelines(hass)
|
||||||
|
assert list(pipelines) == [
|
||||||
|
Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=ANY,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine=None,
|
||||||
|
stt_language=None,
|
||||||
|
tts_engine=None,
|
||||||
|
tts_language=None,
|
||||||
|
tts_voice=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("ha_language", "ha_country", "conv_language", "pipeline_language"),
|
("ha_language", "ha_country", "conv_language", "pipeline_language"),
|
||||||
[
|
[
|
||||||
|
|
|
@ -8,6 +8,13 @@ from homeassistant.components.cloud import const, prefs
|
||||||
|
|
||||||
from . import mock_cloud, mock_cloud_prefs
|
from . import mock_cloud, mock_cloud_prefs
|
||||||
|
|
||||||
|
# Prevent TTS cache from being created
|
||||||
|
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||||
|
init_cache_dir_side_effect,
|
||||||
|
mock_get_cache_files,
|
||||||
|
mock_init_cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_user_data():
|
def mock_user_data():
|
||||||
|
|
|
@ -6,6 +6,11 @@ import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
Pipeline,
|
||||||
|
async_get_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
|
)
|
||||||
from homeassistant.components.cloud import DOMAIN
|
from homeassistant.components.cloud import DOMAIN
|
||||||
from homeassistant.components.cloud.client import CloudClient
|
from homeassistant.components.cloud.client import CloudClient
|
||||||
from homeassistant.components.cloud.const import (
|
from homeassistant.components.cloud.const import (
|
||||||
|
@ -298,23 +303,31 @@ async def test_google_config_should_2fa(
|
||||||
assert not gconf.should_2fa(state)
|
assert not gconf.should_2fa(state)
|
||||||
|
|
||||||
|
|
||||||
async def test_set_username(hass: HomeAssistant) -> None:
|
@patch(
|
||||||
|
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None:
|
||||||
"""Test we set username during login."""
|
"""Test we set username during login."""
|
||||||
prefs = MagicMock(
|
prefs = MagicMock(
|
||||||
alexa_enabled=False,
|
alexa_enabled=False,
|
||||||
google_enabled=False,
|
google_enabled=False,
|
||||||
async_set_username=AsyncMock(return_value=None),
|
async_set_username=AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
client = CloudClient(hass, prefs, None, {}, {})
|
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
|
||||||
client.cloud = MagicMock(is_logged_in=True, username="mock-username")
|
client.cloud = MagicMock(is_logged_in=True, username="mock-username")
|
||||||
await client.cloud_started()
|
await client.on_cloud_connected()
|
||||||
|
|
||||||
assert len(prefs.async_set_username.mock_calls) == 1
|
assert len(prefs.async_set_username.mock_calls) == 1
|
||||||
assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username"
|
assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username"
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
async def test_login_recovers_bad_internet(
|
async def test_login_recovers_bad_internet(
|
||||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test Alexa can recover bad auth."""
|
"""Test Alexa can recover bad auth."""
|
||||||
prefs = Mock(
|
prefs = Mock(
|
||||||
|
@ -322,12 +335,12 @@ async def test_login_recovers_bad_internet(
|
||||||
google_enabled=False,
|
google_enabled=False,
|
||||||
async_set_username=AsyncMock(return_value=None),
|
async_set_username=AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
client = CloudClient(hass, prefs, None, {}, {})
|
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
|
||||||
client.cloud = Mock()
|
client.cloud = Mock()
|
||||||
client._alexa_config = Mock(
|
client._alexa_config = Mock(
|
||||||
async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError)
|
async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError)
|
||||||
)
|
)
|
||||||
await client.cloud_started()
|
await client.on_cloud_connected()
|
||||||
assert len(client._alexa_config.async_enable_proactive_mode.mock_calls) == 1
|
assert len(client._alexa_config.async_enable_proactive_mode.mock_calls) == 1
|
||||||
assert "Unable to activate Alexa Report State" in caplog.text
|
assert "Unable to activate Alexa Report State" in caplog.text
|
||||||
|
|
||||||
|
@ -354,3 +367,29 @@ async def test_system_msg(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
assert response is None
|
assert response is None
|
||||||
assert cloud.client.relayer_region == "xx-earth-616"
|
assert cloud.client.relayer_region == "xx-earth-616"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_cloud_assist_pipeline(
|
||||||
|
hass: HomeAssistant, mock_cloud_setup, mock_cloud_login
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a cloud enabled assist pipeline."""
|
||||||
|
cloud_client: CloudClient = hass.data[DOMAIN].client
|
||||||
|
await cloud_client.cloud_started()
|
||||||
|
assert cloud_client.cloud_pipeline is None
|
||||||
|
assert len(async_get_pipelines(hass)) == 1
|
||||||
|
|
||||||
|
await cloud_client.create_cloud_assist_pipeline()
|
||||||
|
assert cloud_client.cloud_pipeline is not None
|
||||||
|
assert len(async_get_pipelines(hass)) == 2
|
||||||
|
assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=cloud_client.cloud_pipeline,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant Cloud",
|
||||||
|
stt_engine="cloud",
|
||||||
|
stt_language="en-US",
|
||||||
|
tts_engine="cloud",
|
||||||
|
tts_language="en-US",
|
||||||
|
tts_voice="JennyNeural",
|
||||||
|
)
|
||||||
|
|
|
@ -105,7 +105,14 @@ async def test_google_actions_sync_fails(
|
||||||
|
|
||||||
async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
|
async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
|
||||||
"""Test logging in."""
|
"""Test logging in."""
|
||||||
hass.data["cloud"] = MagicMock(login=AsyncMock())
|
create_cloud_assist_pipeline_mock = AsyncMock()
|
||||||
|
hass.data["cloud"] = MagicMock(
|
||||||
|
login=AsyncMock(),
|
||||||
|
client=Mock(
|
||||||
|
cloud_pipeline="12345",
|
||||||
|
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
req = await cloud_client.post(
|
req = await cloud_client.post(
|
||||||
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
@ -113,7 +120,29 @@ async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
|
||||||
|
|
||||||
assert req.status == HTTPStatus.OK
|
assert req.status == HTTPStatus.OK
|
||||||
result = await req.json()
|
result = await req.json()
|
||||||
assert result == {"success": True}
|
assert result == {"success": True, "cloud_pipeline": "12345"}
|
||||||
|
create_cloud_assist_pipeline_mock.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None:
|
||||||
|
"""Test logging in when no assist pipeline is available."""
|
||||||
|
create_cloud_assist_pipeline_mock = AsyncMock()
|
||||||
|
hass.data["cloud"] = MagicMock(
|
||||||
|
login=AsyncMock(),
|
||||||
|
client=Mock(
|
||||||
|
cloud_pipeline=None,
|
||||||
|
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.OK
|
||||||
|
result = await req.json()
|
||||||
|
assert result == {"success": True, "cloud_pipeline": None}
|
||||||
|
create_cloud_assist_pipeline_mock.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
async def test_login_view_random_exception(cloud_client) -> None:
|
async def test_login_view_random_exception(cloud_client) -> None:
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from hass_nabucasa import Cloud
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import cloud
|
from homeassistant.components import cloud
|
||||||
|
@ -134,9 +135,9 @@ async def test_setup_existing_cloud_user(
|
||||||
|
|
||||||
async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
|
async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
|
||||||
"""Test cloud on connect triggers."""
|
"""Test cloud on connect triggers."""
|
||||||
cl = hass.data["cloud"]
|
cl: Cloud = hass.data["cloud"]
|
||||||
|
|
||||||
assert len(cl.iot._on_connect) == 3
|
assert len(cl.iot._on_connect) == 4
|
||||||
|
|
||||||
assert len(hass.states.async_entity_ids("binary_sensor")) == 0
|
assert len(hass.states.async_entity_ids("binary_sensor")) == 0
|
||||||
|
|
||||||
|
@ -152,6 +153,11 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
|
||||||
await cl.iot._on_connect[-1]()
|
await cl.iot._on_connect[-1]()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(hass.states.async_entity_ids("binary_sensor")) == 0
|
||||||
|
|
||||||
|
await cl.client.cloud_started()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
|
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
|
||||||
|
|
||||||
with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load:
|
with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load:
|
||||||
|
|
|
@ -511,6 +511,12 @@ async def test_get_engine_legacy(
|
||||||
TEST_DOMAIN,
|
TEST_DOMAIN,
|
||||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
)
|
)
|
||||||
|
mock_stt_platform(
|
||||||
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
"cloud",
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue