Move cloud stt engine to config entry (#99608)

* Migrate cloud stt to config entry

* Update default engine

* Test config flow

* Migrate pipelines with cloud stt engine to new engine id

* Fix test after rebase

* Update and add comment

* Remove cloud specifics from default stt engine

* Refactor cloud assist pipeline

* Fix cloud stt entity_id

* Try to wait for platforms before creating default pipeline

* Clean up import

* Move function in cloud assist pipeline

* Wait for tts platform loaded in stt migration

* Update deprecation dates

* Clean up not used fixture

* Add test for async_update_pipeline

* Define pipeline update interface better

* Remove leftover

* Fix tests

* Change default engine test

* Add test for missing stt entity during login

* Add and update comments

* Update config entry title
This commit is contained in:
Martin Hjelmare 2023-12-21 13:39:02 +01:00 committed by GitHub
parent f0104d6851
commit e1f31194f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 650 additions and 61 deletions

View file

@ -31,6 +31,7 @@ from .pipeline import (
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
async_setup_pipeline_store, async_setup_pipeline_store,
async_update_pipeline,
) )
from .websocket_api import async_register_websocket_api from .websocket_api import async_register_websocket_api
@ -40,6 +41,7 @@ __all__ = (
"async_get_pipelines", "async_get_pipelines",
"async_setup", "async_setup",
"async_pipeline_from_audio_stream", "async_pipeline_from_audio_stream",
"async_update_pipeline",
"AudioSettings", "AudioSettings",
"Pipeline", "Pipeline",
"PipelineEvent", "PipelineEvent",

View file

@ -43,6 +43,7 @@ from homeassistant.helpers.collection import (
) )
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from homeassistant.util import ( from homeassistant.util import (
dt as dt_util, dt as dt_util,
language as language_util, language as language_util,
@ -276,6 +277,48 @@ def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
return pipeline_data.pipeline_store.data.values() return pipeline_data.pipeline_store.data.values()
async def async_update_pipeline(
hass: HomeAssistant,
pipeline: Pipeline,
*,
conversation_engine: str | UndefinedType = UNDEFINED,
conversation_language: str | UndefinedType = UNDEFINED,
language: str | UndefinedType = UNDEFINED,
name: str | UndefinedType = UNDEFINED,
stt_engine: str | None | UndefinedType = UNDEFINED,
stt_language: str | None | UndefinedType = UNDEFINED,
tts_engine: str | None | UndefinedType = UNDEFINED,
tts_language: str | None | UndefinedType = UNDEFINED,
tts_voice: str | None | UndefinedType = UNDEFINED,
wake_word_entity: str | None | UndefinedType = UNDEFINED,
wake_word_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
updates: dict[str, Any] = pipeline.to_json()
updates.pop("id")
# Refactor this once we bump to Python 3.12
# and have https://peps.python.org/pep-0692/
for key, val in (
("conversation_engine", conversation_engine),
("conversation_language", conversation_language),
("language", language),
("name", name),
("stt_engine", stt_engine),
("stt_language", stt_language),
("tts_engine", tts_engine),
("tts_language", tts_language),
("tts_voice", tts_voice),
("wake_word_entity", wake_word_entity),
("wake_word_id", wake_word_id),
):
if val is not UNDEFINED:
updates[key] = val
await pipeline_data.pipeline_store.async_update_item(pipeline.id, updates)
class PipelineEventType(StrEnum): class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run.""" """Event types emitted during a pipeline run."""

View file

@ -10,6 +10,7 @@ from hass_nabucasa import Cloud
import voluptuous as vol import voluptuous as vol
from homeassistant.components import alexa, google_assistant from homeassistant.components import alexa, google_assistant
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_DESCRIPTION, CONF_DESCRIPTION,
CONF_MODE, CONF_MODE,
@ -51,6 +52,7 @@ from .const import (
CONF_SERVICEHANDLERS_SERVER, CONF_SERVICEHANDLERS_SERVER,
CONF_THINGTALK_SERVER, CONF_THINGTALK_SERVER,
CONF_USER_POOL_ID, CONF_USER_POOL_ID,
DATA_PLATFORMS_SETUP,
DOMAIN, DOMAIN,
MODE_DEV, MODE_DEV,
MODE_PROD, MODE_PROD,
@ -61,6 +63,8 @@ from .subscription import async_subscription_info
DEFAULT_MODE = MODE_PROD DEFAULT_MODE = MODE_PROD
PLATFORMS = [Platform.STT]
SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
@ -262,6 +266,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_manage_legacy_subscription_issue(hass, subscription_info) async_manage_legacy_subscription_issue(hass, subscription_info)
loaded = False loaded = False
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
hass.data[DATA_PLATFORMS_SETUP] = {
Platform.STT: stt_platform_loaded,
Platform.TTS: tts_platform_loaded,
}
async def _on_start() -> None: async def _on_start() -> None:
"""Discover platforms.""" """Discover platforms."""
@ -272,15 +282,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return return
loaded = True loaded = True
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_info = {"platform_loaded": stt_platform_loaded}
tts_info = {"platform_loaded": tts_platform_loaded} tts_info = {"platform_loaded": tts_platform_loaded}
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config) await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config)
await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait()) await tts_platform_loaded.wait()
# The config entry should be loaded after the legacy tts platform is loaded
# to make sure that the tts integration is setup before we try to migrate
# old assist pipelines in the cloud stt entity.
await hass.config_entries.flow.async_init(DOMAIN, context={"source": "system"})
async def _on_connect() -> None: async def _on_connect() -> None:
"""Handle cloud connect.""" """Handle cloud connect."""
@ -304,7 +315,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
cloud.register_on_initialized(_on_initialized) cloud.register_on_initialized(_on_initialized)
await cloud.initialize() await cloud.initialize()
await http_api.async_setup(hass) http_api.async_setup(hass)
account_link.async_setup(hass) account_link.async_setup(hass)
@ -340,3 +351,19 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
await cloud.remote.disconnect() await cloud.remote.disconnect()
cloud.client.prefs.async_listen_updates(remote_prefs_updated) cloud.client.prefs.async_listen_updates(remote_prefs_updated)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
return unload_ok

View file

@ -1,31 +1,48 @@
"""Handle Cloud assist pipelines.""" """Handle Cloud assist pipelines."""
import asyncio
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipelines, async_get_pipelines,
async_setup_pipeline_store, async_setup_pipeline_store,
async_update_pipeline,
) )
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er
from .const import DOMAIN from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
"""Create a cloud assist pipeline.""" """Create a cloud assist pipeline."""
# Wait for stt and tts platforms to set up before creating the pipeline.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
# Make sure the pipeline store is loaded, needed because assist_pipeline # Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud # is an after dependency of cloud
await async_setup_pipeline_store(hass) await async_setup_pipeline_store(hass)
entity_registry = er.async_get(hass)
new_stt_engine_id = entity_registry.async_get_entity_id(
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None:
# If there's no cloud stt entity, we can't create a cloud pipeline.
return None
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None. """Return the ID of a cloud-enabled assist pipeline or None.
Check if a cloud pipeline already exists with Check if a cloud pipeline already exists with either
legacy cloud engine id. legacy or current cloud engine ids.
""" """
for pipeline in async_get_pipelines(hass): for pipeline in async_get_pipelines(hass):
if ( if (
pipeline.conversation_engine == HOME_ASSISTANT_AGENT pipeline.conversation_engine == HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
and pipeline.tts_engine == DOMAIN and pipeline.tts_engine == DOMAIN
): ):
return pipeline.id return pipeline.id
@ -34,7 +51,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
if (cloud_assist_pipeline(hass)) is not None or ( if (cloud_assist_pipeline(hass)) is not None or (
cloud_pipeline := await async_create_default_pipeline( cloud_pipeline := await async_create_default_pipeline(
hass, hass,
stt_engine_id=DOMAIN, stt_engine_id=new_stt_engine_id,
tts_engine_id=DOMAIN, tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud", pipeline_name="Home Assistant Cloud",
) )
@ -42,3 +59,27 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
return None return None
return cloud_pipeline.id return cloud_pipeline.id
async def async_migrate_cloud_pipeline_stt_engine(
hass: HomeAssistant, stt_engine_id: str
) -> None:
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
# Added in 2024.01.0. Can be removed in 2025.01.0.
# We need to make sure that tts is loaded before this migration.
# Assist pipeline will call default engine of tts when setting up the store.
# Wait for the tts platform loaded event here.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await platforms_setup[Platform.TTS].wait()
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if pipeline.stt_engine != DOMAIN:
continue
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)

View file

@ -0,0 +1,23 @@
"""Config flow for the Cloud integration."""
from __future__ import annotations
from typing import Any
from homeassistant.config_entries import ConfigFlow
from homeassistant.data_entry_flow import FlowResult
from .const import DOMAIN
class CloudConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for the Cloud integration."""
VERSION = 1
async def async_step_system(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the system step."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
return self.async_create_entry(title="Home Assistant Cloud", data={})

View file

@ -1,5 +1,6 @@
"""Constants for the cloud component.""" """Constants for the cloud component."""
DOMAIN = "cloud" DOMAIN = "cloud"
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
REQUEST_TIMEOUT = 10 REQUEST_TIMEOUT = 10
PREF_ENABLE_ALEXA = "alexa_enabled" PREF_ENABLE_ALEXA = "alexa_enabled"
@ -64,3 +65,5 @@ MODE_DEV = "development"
MODE_PROD = "production" MODE_PROD = "production"
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update" DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"

View file

@ -28,7 +28,7 @@ from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.http import HomeAssistantView, require_admin from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.location import async_detect_location_info from homeassistant.util.location import async_detect_location_info
@ -66,7 +66,8 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
} }
async def async_setup(hass: HomeAssistant) -> None: @callback
def async_setup(hass: HomeAssistant) -> None:
"""Initialize the HTTP API.""" """Initialize the HTTP API."""
websocket_api.async_register_command(hass, websocket_cloud_status) websocket_api.async_register_command(hass, websocket_cloud_status)
websocket_api.async_register_command(hass, websocket_subscription) websocket_api.async_register_command(hass, websocket_subscription)

View file

@ -1,4 +1,10 @@
{ {
"config": {
"step": {},
"abort": {
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
}
},
"system_health": { "system_health": {
"info": { "info": {
"can_reach_cert_server": "Reach Certificate Server", "can_reach_cert_server": "Reach Certificate Server",

View file

@ -13,37 +13,38 @@ from homeassistant.components.stt import (
AudioCodecs, AudioCodecs,
AudioFormats, AudioFormats,
AudioSampleRates, AudioSampleRates,
Provider,
SpeechMetadata, SpeechMetadata,
SpeechResult, SpeechResult,
SpeechResultState, SpeechResultState,
SpeechToTextEntity,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
from .client import CloudClient from .client import CloudClient
from .const import DOMAIN from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_get_engine( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config_entry: ConfigEntry,
discovery_info: DiscoveryInfoType | None = None, async_add_entities: AddEntitiesCallback,
) -> CloudProvider: ) -> None:
"""Set up Cloud speech component.""" """Set up Home Assistant Cloud speech platform via config entry."""
cloud: Cloud[CloudClient] = hass.data[DOMAIN] cloud: Cloud[CloudClient] = hass.data[DOMAIN]
async_add_entities([CloudProviderEntity(cloud)])
cloud_provider = CloudProvider(cloud)
if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider
class CloudProvider(Provider): class CloudProviderEntity(SpeechToTextEntity):
"""NabuCasa speech API provider.""" """NabuCasa speech API provider."""
_attr_name = "Home Assistant Cloud"
_attr_unique_id = STT_ENTITY_UNIQUE_ID
def __init__(self, cloud: Cloud[CloudClient]) -> None: def __init__(self, cloud: Cloud[CloudClient]) -> None:
"""Home Assistant NabuCasa Speech to text.""" """Home Assistant NabuCasa Speech to text."""
self.cloud = cloud self.cloud = cloud
@ -78,6 +79,10 @@ class CloudProvider(Provider):
"""Return a list of supported channels.""" """Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO] return [AudioChannels.CHANNEL_MONO]
async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
async def async_process_audio_stream( async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult: ) -> SpeechResult:

View file

@ -29,9 +29,6 @@ _LOGGER = logging.getLogger(__name__)
@callback @callback
def async_default_provider(hass: HomeAssistant) -> str | None: def async_default_provider(hass: HomeAssistant) -> str | None:
"""Return the domain of the default provider.""" """Return the domain of the default provider."""
if "cloud" in hass.data[DATA_PROVIDERS]:
return "cloud"
return next(iter(hass.data[DATA_PROVIDERS]), None) return next(iter(hass.data[DATA_PROVIDERS]), None)

View file

@ -1,4 +1,5 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
async_update_pipeline,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -26,6 +28,13 @@ from .conftest import MockSttProvider, MockTTSProvider
from tests.common import flush_store from tests.common import flush_store
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
"""Load the homeassistant integration."""
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def load_homeassistant(hass) -> None: async def load_homeassistant(hass) -> None:
"""Load the homeassistant integration.""" """Load the homeassistant integration."""
@ -478,3 +487,125 @@ async def test_default_pipeline_unsupported_tts_language(
wake_word_entity=None, wake_word_entity=None,
wake_word_id=None, wake_word_id=None,
) )
async def test_update_pipeline(
hass: HomeAssistant,
hass_storage: dict[str, Any],
) -> None:
"""Test async_update_pipeline."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
assert pipelines == [
Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=ANY,
language="en",
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
]
pipeline = pipelines[0]
await async_update_pipeline(
hass,
pipeline,
conversation_engine="homeassistant_1",
conversation_language="de",
language="de",
name="Home Assistant 1",
stt_engine="stt.test_1",
stt_language="de",
tts_engine="test_1",
tts_language="de",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
pipeline = pipelines[0]
assert pipelines == [
Pipeline(
conversation_engine="homeassistant_1",
conversation_language="de",
id=pipeline.id,
language="de",
name="Home Assistant 1",
stt_engine="stt.test_1",
stt_language="de",
tts_engine="test_1",
tts_language="de",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
]
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
"conversation_engine": "homeassistant_1",
"conversation_language": "de",
"id": pipeline.id,
"language": "de",
"name": "Home Assistant 1",
"stt_engine": "stt.test_1",
"stt_language": "de",
"tts_engine": "test_1",
"tts_language": "de",
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
}
await async_update_pipeline(
hass,
pipeline,
stt_engine="stt.test_2",
stt_language="en",
tts_engine="test_2",
tts_language="en",
)
pipelines = async_get_pipelines(hass)
pipelines = list(pipelines)
assert pipelines == [
Pipeline(
conversation_engine="homeassistant_1",
conversation_language="de",
id=pipeline.id,
language="de",
name="Home Assistant 1",
stt_engine="stt.test_2",
stt_language="en",
tts_engine="test_2",
tts_language="en",
tts_voice="test_voice",
wake_word_entity="wake_work.test_1",
wake_word_id="wake_word_id_1",
)
]
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
"conversation_engine": "homeassistant_1",
"conversation_language": "de",
"id": pipeline.id,
"language": "de",
"name": "Home Assistant 1",
"stt_engine": "stt.test_2",
"stt_language": "en",
"tts_engine": "test_2",
"tts_language": "en",
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
}

View file

@ -1,7 +1,8 @@
"""Tests for the cloud component.""" """Tests for the cloud component."""
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from hass_nabucasa import Cloud
from homeassistant.components import cloud from homeassistant.components import cloud
from homeassistant.components.cloud import const, prefs as cloud_prefs from homeassistant.components.cloud import const, prefs as cloud_prefs
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -14,7 +15,7 @@ async def mock_cloud(hass, config=None):
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, cloud.DOMAIN, {"cloud": config or {}}) assert await async_setup_component(hass, cloud.DOMAIN, {"cloud": config or {}})
cloud_inst = hass.data["cloud"] cloud_inst: Cloud = hass.data["cloud"]
with patch("hass_nabucasa.Cloud.run_executor", AsyncMock(return_value=None)): with patch("hass_nabucasa.Cloud.run_executor", AsyncMock(return_value=None)):
await cloud_inst.initialize() await cloud_inst.initialize()

View file

@ -0,0 +1,40 @@
"""Test the Home Assistant Cloud config flow."""
from unittest.mock import patch
from homeassistant.components.cloud.const import DOMAIN
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
async def test_config_flow(hass: HomeAssistant) -> None:
"""Test create cloud entry."""
with patch(
"homeassistant.components.cloud.async_setup", return_value=True
) as mock_setup, patch(
"homeassistant.components.cloud.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "system"}
)
assert result["type"] == "create_entry"
assert result["title"] == "Home Assistant Cloud"
assert result["data"] == {}
await hass.async_block_till_done()
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
async def test_multiple_entries(hass: HomeAssistant) -> None:
"""Test creating multiple cloud entries."""
config_entry = MockConfigEntry(domain=DOMAIN)
config_entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "system"}
)
assert result["type"] == "abort"
assert result["reason"] == "single_instance_allowed"

View file

@ -46,6 +46,26 @@ PIPELINE_DATA_LEGACY = {
"preferred_item": "12345", "preferred_item": "12345",
} }
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "homeassistant",
"conversation_language": "language_1",
"id": "12345",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "stt.home_assistant_cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "12345",
}
PIPELINE_DATA_OTHER = { PIPELINE_DATA_OTHER = {
"items": [ "items": [
{ {
@ -127,7 +147,34 @@ async def test_google_actions_sync_fails(
assert mock_request_sync.call_count == 1 assert mock_request_sync.call_count == 1
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA_LEGACY]) async def test_login_view_missing_stt_entity(
hass: HomeAssistant,
setup_cloud: None,
entity_registry: er.EntityRegistry,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when the cloud stt entity is missing."""
# Make sure that the cloud stt entity does not exist.
entity_registry.async_remove("stt.home_assistant_cloud")
await hass.async_block_till_done()
cloud_client = await hass_client()
# We assume the user needs to login again for some reason.
with patch(
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_not_awaited()
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA, PIPELINE_DATA_LEGACY])
async def test_login_view_existing_pipeline( async def test_login_view_existing_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
cloud: MagicMock, cloud: MagicMock,
@ -195,7 +242,7 @@ async def test_login_view_create_pipeline(
assert result == {"success": True, "cloud_pipeline": "12345"} assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with( create_pipeline_mock.assert_awaited_once_with(
hass, hass,
stt_engine_id="cloud", stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud", tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud", pipeline_name="Home Assistant Cloud",
) )
@ -234,7 +281,7 @@ async def test_login_view_create_pipeline_fail(
assert result == {"success": True, "cloud_pipeline": None} assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_awaited_once_with( create_pipeline_mock.assert_awaited_once_with(
hass, hass,
stt_engine_id="cloud", stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud", tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud", pipeline_name="Home Assistant Cloud",
) )

View file

@ -0,0 +1,201 @@
"""Test the speech-to-text platform for the cloud integration."""
from collections.abc import AsyncGenerator
from copy import deepcopy
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.voice import STTResponse, VoiceError
import pytest
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud import DOMAIN
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.typing import ClientSessionGenerator
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "conversation_engine_1",
"conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_2",
"conversation_language": "language_2",
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"stt_language": "language_2",
"tts_engine": "tts_engine_2",
"tts_language": "language_2",
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_3",
"conversation_language": "language_3",
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
}
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
"""Load the homeassistant integration."""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
"""Load the homeassistant integration."""
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
yield
@pytest.mark.parametrize(
("mock_process_stt", "expected_response_data"),
[
(
AsyncMock(return_value=STTResponse(True, "Turn the Kitchen Lights on")),
{"text": "Turn the Kitchen Lights on", "result": "success"},
),
(AsyncMock(side_effect=VoiceError("Boom!")), {"text": None, "result": "error"}),
],
)
async def test_cloud_speech(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
mock_process_stt: AsyncMock,
expected_response_data: dict[str, Any],
) -> None:
"""Test cloud text-to-speech."""
cloud.voice.process_stt = mock_process_stt
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state == STATE_UNKNOWN
client = await hass_client()
response = await client.post(
"/api/stt/stt.home_assistant_cloud",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
" language=de-DE"
)
},
data=b"Test",
)
response_data = await response.json()
assert mock_process_stt.call_count == 1
assert (
mock_process_stt.call_args.kwargs["content_type"]
== "audio/wav; codecs=audio/pcm; samplerate=16000"
)
assert mock_process_stt.call_args.kwargs["language"] == "de-DE"
assert response.status == HTTPStatus.OK
assert response_data == expected_response_data
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
async def test_migrating_pipelines(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
hass_storage: dict[str, Any],
) -> None:
"""Test migrating pipelines when cloud stt entity is added."""
cloud.voice.process_stt = AsyncMock(
return_value=STTResponse(True, "Turn the Kitchen Lights on")
)
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,
"key": "assist_pipeline.pipelines",
"data": deepcopy(PIPELINE_DATA),
}
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
await hass.async_block_till_done()
state = hass.states.get("stt.home_assistant_cloud")
assert state
assert state.state == STATE_UNKNOWN
# The stt engine should be updated to the new cloud stt engine id.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
== "stt.home_assistant_cloud"
)
# The other items should stay the same.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
== "conversation_engine_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
== "language_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
== "Arnold Schwarzenegger"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]

View file

@ -121,12 +121,20 @@ class STTFlow(ConfigFlow):
"""Test flow.""" """Test flow."""
@pytest.fixture(autouse=True) @pytest.fixture(name="config_flow_test_domain")
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: def config_flow_test_domain_fixture() -> str:
"""Mock config flow.""" """Test domain fixture."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow") return TEST_DOMAIN
with mock_config_flow(TEST_DOMAIN, STTFlow):
@pytest.fixture(autouse=True)
def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domain: str
) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
with mock_config_flow(config_flow_test_domain, STTFlow):
yield yield
@ -137,6 +145,7 @@ async def setup_fixture(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
) -> MockProvider | MockProviderEntity: ) -> MockProvider | MockProviderEntity:
"""Set up the test environment.""" """Set up the test environment."""
provider: MockProvider | MockProviderEntity
if request.param == "mock_setup": if request.param == "mock_setup":
provider = MockProvider() provider = MockProvider()
await mock_setup(hass, tmp_path, provider) await mock_setup(hass, tmp_path, provider)
@ -166,7 +175,10 @@ async def mock_setup(
async def mock_config_entry_setup( async def mock_config_entry_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Set up a test provider via config entry.""" """Set up a test provider via config entry."""
@ -187,7 +199,7 @@ async def mock_config_entry_setup(
mock_integration( mock_integration(
hass, hass,
MockModule( MockModule(
TEST_DOMAIN, test_domain,
async_setup_entry=async_setup_entry_init, async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init, async_unload_entry=async_unload_entry_init,
), ),
@ -201,9 +213,9 @@ async def mock_config_entry_setup(
"""Set up test stt platform via config entry.""" """Set up test stt platform via config entry."""
async_add_entities([mock_provider_entity]) async_add_entities([mock_provider_entity])
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform) mock_stt_entity_platform(hass, tmp_path, test_domain, async_setup_entry_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN) config_entry = MockConfigEntry(domain=test_domain)
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -456,7 +468,11 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
assert async_default_engine(hass) is None assert async_default_engine(hass) is None
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None: async def test_default_engine(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
) -> None:
"""Test async_default_engine.""" """Test async_default_engine."""
mock_stt_platform( mock_stt_platform(
hass, hass,
@ -479,26 +495,31 @@ async def test_default_engine_entity(
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}" assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None: @pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
async def test_default_engine_prefer_provider(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider: MockProvider,
config_flow_test_domain: str,
) -> None:
"""Test async_default_engine.""" """Test async_default_engine."""
mock_stt_platform( mock_provider_entity.url_path = "stt.new_test"
hass, mock_provider_entity._attr_name = "New test"
tmp_path,
TEST_DOMAIN, await mock_setup(hass, tmp_path, mock_provider)
async_get_engine=AsyncMock(return_value=mock_provider), await mock_config_entry_setup(
) hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
mock_stt_platform(
hass,
tmp_path,
"cloud",
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert async_default_engine(hass) == "cloud" entity_engine = async_get_speech_to_text_engine(hass, "stt.new_test")
assert entity_engine is not None
assert entity_engine.name == "New test"
provider_engine = async_get_speech_to_text_engine(hass, "test")
assert provider_engine is not None
assert provider_engine.name == "test"
assert async_default_engine(hass) == "test"
async def test_get_engine_legacy( async def test_get_engine_legacy(