Move cloud stt engine to config entry (#99608)
* Migrate cloud stt to config entry * Update default engine * Test config flow * Migrate pipelines with cloud stt engine to new engine id * Fix test after rebase * Update and add comment * Remove cloud specifics from default stt engine * Refactor cloud assist pipeline * Fix cloud stt entity_id * Try to wait for platforms before creating default pipeline * Clean up import * Move function in cloud assist pipeline * Wait for tts platform loaded in stt migration * Update deprecation dates * Clean up not used fixture * Add test for async_update_pipeline * Define pipeline update interface better * Remove leftover * Fix tests * Change default engine test * Add test for missing stt entity during login * Add and update comments * Update config entry title
This commit is contained in:
parent
f0104d6851
commit
e1f31194f7
16 changed files with 650 additions and 61 deletions
|
@ -31,6 +31,7 @@ from .pipeline import (
|
|||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_setup_pipeline_store,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
|
@ -40,6 +41,7 @@ __all__ = (
|
|||
"async_get_pipelines",
|
||||
"async_setup",
|
||||
"async_pipeline_from_audio_stream",
|
||||
"async_update_pipeline",
|
||||
"AudioSettings",
|
||||
"Pipeline",
|
||||
"PipelineEvent",
|
||||
|
|
|
@ -43,6 +43,7 @@ from homeassistant.helpers.collection import (
|
|||
)
|
||||
from homeassistant.helpers.singleton import singleton
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||
from homeassistant.util import (
|
||||
dt as dt_util,
|
||||
language as language_util,
|
||||
|
@ -276,6 +277,48 @@ def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
|
|||
return pipeline_data.pipeline_store.data.values()
|
||||
|
||||
|
||||
async def async_update_pipeline(
|
||||
hass: HomeAssistant,
|
||||
pipeline: Pipeline,
|
||||
*,
|
||||
conversation_engine: str | UndefinedType = UNDEFINED,
|
||||
conversation_language: str | UndefinedType = UNDEFINED,
|
||||
language: str | UndefinedType = UNDEFINED,
|
||||
name: str | UndefinedType = UNDEFINED,
|
||||
stt_engine: str | None | UndefinedType = UNDEFINED,
|
||||
stt_language: str | None | UndefinedType = UNDEFINED,
|
||||
tts_engine: str | None | UndefinedType = UNDEFINED,
|
||||
tts_language: str | None | UndefinedType = UNDEFINED,
|
||||
tts_voice: str | None | UndefinedType = UNDEFINED,
|
||||
wake_word_entity: str | None | UndefinedType = UNDEFINED,
|
||||
wake_word_id: str | None | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Update a pipeline."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
||||
updates: dict[str, Any] = pipeline.to_json()
|
||||
updates.pop("id")
|
||||
# Refactor this once we bump to Python 3.12
|
||||
# and have https://peps.python.org/pep-0692/
|
||||
for key, val in (
|
||||
("conversation_engine", conversation_engine),
|
||||
("conversation_language", conversation_language),
|
||||
("language", language),
|
||||
("name", name),
|
||||
("stt_engine", stt_engine),
|
||||
("stt_language", stt_language),
|
||||
("tts_engine", tts_engine),
|
||||
("tts_language", tts_language),
|
||||
("tts_voice", tts_voice),
|
||||
("wake_word_entity", wake_word_entity),
|
||||
("wake_word_id", wake_word_id),
|
||||
):
|
||||
if val is not UNDEFINED:
|
||||
updates[key] = val
|
||||
|
||||
await pipeline_data.pipeline_store.async_update_item(pipeline.id, updates)
|
||||
|
||||
|
||||
class PipelineEventType(StrEnum):
|
||||
"""Event types emitted during a pipeline run."""
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from hass_nabucasa import Cloud
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import alexa, google_assistant
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
CONF_DESCRIPTION,
|
||||
CONF_MODE,
|
||||
|
@ -51,6 +52,7 @@ from .const import (
|
|||
CONF_SERVICEHANDLERS_SERVER,
|
||||
CONF_THINGTALK_SERVER,
|
||||
CONF_USER_POOL_ID,
|
||||
DATA_PLATFORMS_SETUP,
|
||||
DOMAIN,
|
||||
MODE_DEV,
|
||||
MODE_PROD,
|
||||
|
@ -61,6 +63,8 @@ from .subscription import async_subscription_info
|
|||
|
||||
DEFAULT_MODE = MODE_PROD
|
||||
|
||||
PLATFORMS = [Platform.STT]
|
||||
|
||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||
|
||||
|
@ -262,6 +266,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async_manage_legacy_subscription_issue(hass, subscription_info)
|
||||
|
||||
loaded = False
|
||||
stt_platform_loaded = asyncio.Event()
|
||||
tts_platform_loaded = asyncio.Event()
|
||||
hass.data[DATA_PLATFORMS_SETUP] = {
|
||||
Platform.STT: stt_platform_loaded,
|
||||
Platform.TTS: tts_platform_loaded,
|
||||
}
|
||||
|
||||
async def _on_start() -> None:
|
||||
"""Discover platforms."""
|
||||
|
@ -272,15 +282,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return
|
||||
loaded = True
|
||||
|
||||
stt_platform_loaded = asyncio.Event()
|
||||
tts_platform_loaded = asyncio.Event()
|
||||
stt_info = {"platform_loaded": stt_platform_loaded}
|
||||
tts_info = {"platform_loaded": tts_platform_loaded}
|
||||
|
||||
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
|
||||
await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config)
|
||||
await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config)
|
||||
await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait())
|
||||
await tts_platform_loaded.wait()
|
||||
|
||||
# The config entry should be loaded after the legacy tts platform is loaded
|
||||
# to make sure that the tts integration is setup before we try to migrate
|
||||
# old assist pipelines in the cloud stt entity.
|
||||
await hass.config_entries.flow.async_init(DOMAIN, context={"source": "system"})
|
||||
|
||||
async def _on_connect() -> None:
|
||||
"""Handle cloud connect."""
|
||||
|
@ -304,7 +315,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
cloud.register_on_initialized(_on_initialized)
|
||||
|
||||
await cloud.initialize()
|
||||
await http_api.async_setup(hass)
|
||||
http_api.async_setup(hass)
|
||||
|
||||
account_link.async_setup(hass)
|
||||
|
||||
|
@ -340,3 +351,19 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
|
|||
await cloud.remote.disconnect()
|
||||
|
||||
cloud.client.prefs.async_listen_updates(remote_prefs_updated)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
|
||||
stt_platform_loaded.set()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
return unload_ok
|
||||
|
|
|
@ -1,31 +1,48 @@
|
|||
"""Handle Cloud assist pipelines."""
|
||||
import asyncio
|
||||
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
async_create_default_pipeline,
|
||||
async_get_pipelines,
|
||||
async_setup_pipeline_store,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
||||
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.helpers.entity_registry as er
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
|
||||
|
||||
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Create a cloud assist pipeline."""
|
||||
# Wait for stt and tts platforms to set up before creating the pipeline.
|
||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
# is an after dependency of cloud
|
||||
await async_setup_pipeline_store(hass)
|
||||
|
||||
entity_registry = er.async_get(hass)
|
||||
new_stt_engine_id = entity_registry.async_get_entity_id(
|
||||
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
)
|
||||
if new_stt_engine_id is None:
|
||||
# If there's no cloud stt entity, we can't create a cloud pipeline.
|
||||
return None
|
||||
|
||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Return the ID of a cloud-enabled assist pipeline or None.
|
||||
|
||||
Check if a cloud pipeline already exists with
|
||||
legacy cloud engine id.
|
||||
Check if a cloud pipeline already exists with either
|
||||
legacy or current cloud engine ids.
|
||||
"""
|
||||
for pipeline in async_get_pipelines(hass):
|
||||
if (
|
||||
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
||||
and pipeline.stt_engine == DOMAIN
|
||||
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
|
||||
and pipeline.tts_engine == DOMAIN
|
||||
):
|
||||
return pipeline.id
|
||||
|
@ -34,7 +51,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
if (cloud_assist_pipeline(hass)) is not None or (
|
||||
cloud_pipeline := await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id=DOMAIN,
|
||||
stt_engine_id=new_stt_engine_id,
|
||||
tts_engine_id=DOMAIN,
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
@ -42,3 +59,27 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
return None
|
||||
|
||||
return cloud_pipeline.id
|
||||
|
||||
|
||||
async def async_migrate_cloud_pipeline_stt_engine(
|
||||
hass: HomeAssistant, stt_engine_id: str
|
||||
) -> None:
|
||||
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
|
||||
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
|
||||
# Added in 2024.01.0. Can be removed in 2025.01.0.
|
||||
|
||||
# We need to make sure that tts is loaded before this migration.
|
||||
# Assist pipeline will call default engine of tts when setting up the store.
|
||||
# Wait for the tts platform loaded event here.
|
||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||
await platforms_setup[Platform.TTS].wait()
|
||||
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
# is an after dependency of cloud
|
||||
await async_setup_pipeline_store(hass)
|
||||
|
||||
pipelines = async_get_pipelines(hass)
|
||||
for pipeline in pipelines:
|
||||
if pipeline.stt_engine != DOMAIN:
|
||||
continue
|
||||
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)
|
||||
|
|
23
homeassistant/components/cloud/config_flow.py
Normal file
23
homeassistant/components/cloud/config_flow.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
"""Config flow for the Cloud integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.config_entries import ConfigFlow
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
class CloudConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for the Cloud integration."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_system(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the system step."""
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
return self.async_create_entry(title="Home Assistant Cloud", data={})
|
|
@ -1,5 +1,6 @@
|
|||
"""Constants for the cloud component."""
|
||||
DOMAIN = "cloud"
|
||||
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
|
||||
REQUEST_TIMEOUT = 10
|
||||
|
||||
PREF_ENABLE_ALEXA = "alexa_enabled"
|
||||
|
@ -64,3 +65,5 @@ MODE_DEV = "development"
|
|||
MODE_PROD = "production"
|
||||
|
||||
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
|
||||
|
||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||
|
|
|
@ -28,7 +28,7 @@ from homeassistant.components.homeassistant import exposed_entities
|
|||
from homeassistant.components.http import HomeAssistantView, require_admin
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.util.location import async_detect_location_info
|
||||
|
@ -66,7 +66,8 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
|
|||
}
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant) -> None:
|
||||
@callback
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Initialize the HTTP API."""
|
||||
websocket_api.async_register_command(hass, websocket_cloud_status)
|
||||
websocket_api.async_register_command(hass, websocket_subscription)
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
{
|
||||
"config": {
|
||||
"step": {},
|
||||
"abort": {
|
||||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
|
||||
}
|
||||
},
|
||||
"system_health": {
|
||||
"info": {
|
||||
"can_reach_cert_server": "Reach Certificate Server",
|
||||
|
|
|
@ -13,37 +13,38 @@ from homeassistant.components.stt import (
|
|||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
Provider,
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
SpeechResultState,
|
||||
SpeechToTextEntity,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
|
||||
from .client import CloudClient
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_get_engine(
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> CloudProvider:
|
||||
"""Set up Cloud speech component."""
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Home Assistant Cloud speech platform via config entry."""
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
|
||||
cloud_provider = CloudProvider(cloud)
|
||||
if discovery_info is not None:
|
||||
discovery_info["platform_loaded"].set()
|
||||
return cloud_provider
|
||||
async_add_entities([CloudProviderEntity(cloud)])
|
||||
|
||||
|
||||
class CloudProvider(Provider):
|
||||
class CloudProviderEntity(SpeechToTextEntity):
|
||||
"""NabuCasa speech API provider."""
|
||||
|
||||
_attr_name = "Home Assistant Cloud"
|
||||
_attr_unique_id = STT_ENTITY_UNIQUE_ID
|
||||
|
||||
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
||||
"""Home Assistant NabuCasa Speech to text."""
|
||||
self.cloud = cloud
|
||||
|
@ -78,6 +79,10 @@ class CloudProvider(Provider):
|
|||
"""Return a list of supported channels."""
|
||||
return [AudioChannels.CHANNEL_MONO]
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity is about to be added to hass."""
|
||||
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
|
|
|
@ -29,9 +29,6 @@ _LOGGER = logging.getLogger(__name__)
|
|||
@callback
|
||||
def async_default_provider(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain of the default provider."""
|
||||
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||
return "cloud"
|
||||
|
||||
return next(iter(hass.data[DATA_PROVIDERS]), None)
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
|
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -26,6 +28,13 @@ from .conftest import MockSttProvider, MockTTSProvider
|
|||
from tests.common import flush_store
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||
"""Load the homeassistant integration."""
|
||||
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def load_homeassistant(hass) -> None:
|
||||
"""Load the homeassistant integration."""
|
||||
|
@ -478,3 +487,125 @@ async def test_default_pipeline_unsupported_tts_language(
|
|||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_update_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_storage: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test async_update_pipeline."""
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipelines = async_get_pipelines(hass)
|
||||
pipelines = list(pipelines)
|
||||
assert pipelines == [
|
||||
Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=ANY,
|
||||
language="en",
|
||||
name="Home Assistant",
|
||||
stt_engine=None,
|
||||
stt_language=None,
|
||||
tts_engine=None,
|
||||
tts_language=None,
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
]
|
||||
|
||||
pipeline = pipelines[0]
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
pipeline,
|
||||
conversation_engine="homeassistant_1",
|
||||
conversation_language="de",
|
||||
language="de",
|
||||
name="Home Assistant 1",
|
||||
stt_engine="stt.test_1",
|
||||
stt_language="de",
|
||||
tts_engine="test_1",
|
||||
tts_language="de",
|
||||
tts_voice="test_voice",
|
||||
wake_word_entity="wake_work.test_1",
|
||||
wake_word_id="wake_word_id_1",
|
||||
)
|
||||
|
||||
pipelines = async_get_pipelines(hass)
|
||||
pipelines = list(pipelines)
|
||||
pipeline = pipelines[0]
|
||||
assert pipelines == [
|
||||
Pipeline(
|
||||
conversation_engine="homeassistant_1",
|
||||
conversation_language="de",
|
||||
id=pipeline.id,
|
||||
language="de",
|
||||
name="Home Assistant 1",
|
||||
stt_engine="stt.test_1",
|
||||
stt_language="de",
|
||||
tts_engine="test_1",
|
||||
tts_language="de",
|
||||
tts_voice="test_voice",
|
||||
wake_word_entity="wake_work.test_1",
|
||||
wake_word_id="wake_word_id_1",
|
||||
)
|
||||
]
|
||||
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
|
||||
"conversation_engine": "homeassistant_1",
|
||||
"conversation_language": "de",
|
||||
"id": pipeline.id,
|
||||
"language": "de",
|
||||
"name": "Home Assistant 1",
|
||||
"stt_engine": "stt.test_1",
|
||||
"stt_language": "de",
|
||||
"tts_engine": "test_1",
|
||||
"tts_language": "de",
|
||||
"tts_voice": "test_voice",
|
||||
"wake_word_entity": "wake_work.test_1",
|
||||
"wake_word_id": "wake_word_id_1",
|
||||
}
|
||||
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
pipeline,
|
||||
stt_engine="stt.test_2",
|
||||
stt_language="en",
|
||||
tts_engine="test_2",
|
||||
tts_language="en",
|
||||
)
|
||||
|
||||
pipelines = async_get_pipelines(hass)
|
||||
pipelines = list(pipelines)
|
||||
assert pipelines == [
|
||||
Pipeline(
|
||||
conversation_engine="homeassistant_1",
|
||||
conversation_language="de",
|
||||
id=pipeline.id,
|
||||
language="de",
|
||||
name="Home Assistant 1",
|
||||
stt_engine="stt.test_2",
|
||||
stt_language="en",
|
||||
tts_engine="test_2",
|
||||
tts_language="en",
|
||||
tts_voice="test_voice",
|
||||
wake_word_entity="wake_work.test_1",
|
||||
wake_word_id="wake_word_id_1",
|
||||
)
|
||||
]
|
||||
assert len(hass_storage[STORAGE_KEY]["data"]["items"]) == 1
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0] == {
|
||||
"conversation_engine": "homeassistant_1",
|
||||
"conversation_language": "de",
|
||||
"id": pipeline.id,
|
||||
"language": "de",
|
||||
"name": "Home Assistant 1",
|
||||
"stt_engine": "stt.test_2",
|
||||
"stt_language": "en",
|
||||
"tts_engine": "test_2",
|
||||
"tts_language": "en",
|
||||
"tts_voice": "test_voice",
|
||||
"wake_word_entity": "wake_work.test_1",
|
||||
"wake_word_id": "wake_word_id_1",
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Tests for the cloud component."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
|
||||
from homeassistant.components import cloud
|
||||
from homeassistant.components.cloud import const, prefs as cloud_prefs
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -14,7 +15,7 @@ async def mock_cloud(hass, config=None):
|
|||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
assert await async_setup_component(hass, cloud.DOMAIN, {"cloud": config or {}})
|
||||
cloud_inst = hass.data["cloud"]
|
||||
cloud_inst: Cloud = hass.data["cloud"]
|
||||
with patch("hass_nabucasa.Cloud.run_executor", AsyncMock(return_value=None)):
|
||||
await cloud_inst.initialize()
|
||||
|
||||
|
|
40
tests/components/cloud/test_config_flow.py
Normal file
40
tests/components/cloud/test_config_flow.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
"""Test the Home Assistant Cloud config flow."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.cloud.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_config_flow(hass: HomeAssistant) -> None:
|
||||
"""Test create cloud entry."""
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.cloud.async_setup", return_value=True
|
||||
) as mock_setup, patch(
|
||||
"homeassistant.components.cloud.async_setup_entry",
|
||||
return_value=True,
|
||||
) as mock_setup_entry:
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": "system"}
|
||||
)
|
||||
assert result["type"] == "create_entry"
|
||||
assert result["title"] == "Home Assistant Cloud"
|
||||
assert result["data"] == {}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_multiple_entries(hass: HomeAssistant) -> None:
|
||||
"""Test creating multiple cloud entries."""
|
||||
config_entry = MockConfigEntry(domain=DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": "system"}
|
||||
)
|
||||
assert result["type"] == "abort"
|
||||
assert result["reason"] == "single_instance_allowed"
|
|
@ -46,6 +46,26 @@ PIPELINE_DATA_LEGACY = {
|
|||
"preferred_item": "12345",
|
||||
}
|
||||
|
||||
PIPELINE_DATA = {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": "language_1",
|
||||
"id": "12345",
|
||||
"language": "language_1",
|
||||
"name": "Home Assistant Cloud",
|
||||
"stt_engine": "stt.home_assistant_cloud",
|
||||
"stt_language": "language_1",
|
||||
"tts_engine": "cloud",
|
||||
"tts_language": "language_1",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
],
|
||||
"preferred_item": "12345",
|
||||
}
|
||||
|
||||
PIPELINE_DATA_OTHER = {
|
||||
"items": [
|
||||
{
|
||||
|
@ -127,7 +147,34 @@ async def test_google_actions_sync_fails(
|
|||
assert mock_request_sync.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA_LEGACY])
|
||||
async def test_login_view_missing_stt_entity(
|
||||
hass: HomeAssistant,
|
||||
setup_cloud: None,
|
||||
entity_registry: er.EntityRegistry,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test logging in when the cloud stt entity is missing."""
|
||||
# Make sure that the cloud stt entity does not exist.
|
||||
entity_registry.async_remove("stt.home_assistant_cloud")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
cloud_client = await hass_client()
|
||||
|
||||
# We assume the user needs to login again for some reason.
|
||||
with patch(
|
||||
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||
) as create_pipeline_mock:
|
||||
req = await cloud_client.post(
|
||||
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||
)
|
||||
|
||||
assert req.status == HTTPStatus.OK
|
||||
result = await req.json()
|
||||
assert result == {"success": True, "cloud_pipeline": None}
|
||||
create_pipeline_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipeline_data", [PIPELINE_DATA, PIPELINE_DATA_LEGACY])
|
||||
async def test_login_view_existing_pipeline(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
|
@ -195,7 +242,7 @@ async def test_login_view_create_pipeline(
|
|||
assert result == {"success": True, "cloud_pipeline": "12345"}
|
||||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="cloud",
|
||||
stt_engine_id="stt.home_assistant_cloud",
|
||||
tts_engine_id="cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
@ -234,7 +281,7 @@ async def test_login_view_create_pipeline_fail(
|
|||
assert result == {"success": True, "cloud_pipeline": None}
|
||||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="cloud",
|
||||
stt_engine_id="stt.home_assistant_cloud",
|
||||
tts_engine_id="cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
|
201
tests/components/cloud/test_stt.py
Normal file
201
tests/components/cloud/test_stt.py
Normal file
|
@ -0,0 +1,201 @@
|
|||
"""Test the speech-to-text platform for the cloud integration."""
|
||||
from collections.abc import AsyncGenerator
|
||||
from copy import deepcopy
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hass_nabucasa.voice import STTResponse, VoiceError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||
from homeassistant.components.cloud import DOMAIN
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
PIPELINE_DATA = {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"conversation_language": "language_1",
|
||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"language": "language_1",
|
||||
"name": "Home Assistant Cloud",
|
||||
"stt_engine": "cloud",
|
||||
"stt_language": "language_1",
|
||||
"tts_engine": "cloud",
|
||||
"tts_language": "language_1",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_2",
|
||||
"conversation_language": "language_2",
|
||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||
"language": "language_2",
|
||||
"name": "name_2",
|
||||
"stt_engine": "stt_engine_2",
|
||||
"stt_language": "language_2",
|
||||
"tts_engine": "tts_engine_2",
|
||||
"tts_language": "language_2",
|
||||
"tts_voice": "The Voice",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_3",
|
||||
"conversation_language": "language_3",
|
||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||
"language": "language_3",
|
||||
"name": "name_3",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
],
|
||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def load_homeassistant(hass: HomeAssistant) -> None:
|
||||
"""Load the homeassistant integration."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||
"""Load the homeassistant integration."""
|
||||
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_process_stt", "expected_response_data"),
|
||||
[
|
||||
(
|
||||
AsyncMock(return_value=STTResponse(True, "Turn the Kitchen Lights on")),
|
||||
{"text": "Turn the Kitchen Lights on", "result": "success"},
|
||||
),
|
||||
(AsyncMock(side_effect=VoiceError("Boom!")), {"text": None, "result": "error"}),
|
||||
],
|
||||
)
|
||||
async def test_cloud_speech(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_process_stt: AsyncMock,
|
||||
expected_response_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test cloud text-to-speech."""
|
||||
cloud.voice.process_stt = mock_process_stt
|
||||
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
|
||||
state = hass.states.get("stt.home_assistant_cloud")
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
response = await client.post(
|
||||
"/api/stt/stt.home_assistant_cloud",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
" language=de-DE"
|
||||
)
|
||||
},
|
||||
data=b"Test",
|
||||
)
|
||||
response_data = await response.json()
|
||||
|
||||
assert mock_process_stt.call_count == 1
|
||||
assert (
|
||||
mock_process_stt.call_args.kwargs["content_type"]
|
||||
== "audio/wav; codecs=audio/pcm; samplerate=16000"
|
||||
)
|
||||
assert mock_process_stt.call_args.kwargs["language"] == "de-DE"
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert response_data == expected_response_data
|
||||
|
||||
state = hass.states.get("stt.home_assistant_cloud")
|
||||
assert state
|
||||
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
|
||||
|
||||
async def test_migrating_pipelines(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
hass_storage: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test migrating pipelines when cloud stt entity is added."""
|
||||
cloud.voice.process_stt = AsyncMock(
|
||||
return_value=STTResponse(True, "Turn the Kitchen Lights on")
|
||||
)
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"minor_version": 1,
|
||||
"key": "assist_pipeline.pipelines",
|
||||
"data": deepcopy(PIPELINE_DATA),
|
||||
}
|
||||
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("stt.home_assistant_cloud")
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
# The stt engine should be updated to the new cloud stt engine id.
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
|
||||
== "stt.home_assistant_cloud"
|
||||
)
|
||||
|
||||
# The other items should stay the same.
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
|
||||
== "conversation_engine_1"
|
||||
)
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
|
||||
== "language_1"
|
||||
)
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
|
||||
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
||||
== "Arnold Schwarzenegger"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]
|
|
@ -121,12 +121,20 @@ class STTFlow(ConfigFlow):
|
|||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
@pytest.fixture(name="config_flow_test_domain")
|
||||
def config_flow_test_domain_fixture() -> str:
|
||||
"""Test domain fixture."""
|
||||
return TEST_DOMAIN
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, STTFlow):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(
|
||||
hass: HomeAssistant, config_flow_test_domain: str
|
||||
) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
|
||||
|
||||
with mock_config_flow(config_flow_test_domain, STTFlow):
|
||||
yield
|
||||
|
||||
|
||||
|
@ -137,6 +145,7 @@ async def setup_fixture(
|
|||
request: pytest.FixtureRequest,
|
||||
) -> MockProvider | MockProviderEntity:
|
||||
"""Set up the test environment."""
|
||||
provider: MockProvider | MockProviderEntity
|
||||
if request.param == "mock_setup":
|
||||
provider = MockProvider()
|
||||
await mock_setup(hass, tmp_path, provider)
|
||||
|
@ -166,7 +175,10 @@ async def mock_setup(
|
|||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
test_domain: str = TEST_DOMAIN,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test provider via config entry."""
|
||||
|
||||
|
@ -187,7 +199,7 @@ async def mock_config_entry_setup(
|
|||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
test_domain,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
|
@ -201,9 +213,9 @@ async def mock_config_entry_setup(
|
|||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_provider_entity])
|
||||
|
||||
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
|
||||
mock_stt_entity_platform(hass, tmp_path, test_domain, async_setup_entry_platform)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry = MockConfigEntry(domain=test_domain)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -456,7 +468,11 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
|||
assert async_default_engine(hass) is None
|
||||
|
||||
|
||||
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
async def test_default_engine(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
|
@ -479,26 +495,31 @@ async def test_default_engine_entity(
|
|||
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
||||
|
||||
|
||||
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
|
||||
async def test_default_engine_prefer_provider(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
mock_provider: MockProvider,
|
||||
config_flow_test_domain: str,
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
TEST_DOMAIN,
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
"cloud",
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(
|
||||
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||
mock_provider_entity.url_path = "stt.new_test"
|
||||
mock_provider_entity._attr_name = "New test"
|
||||
|
||||
await mock_setup(hass, tmp_path, mock_provider)
|
||||
await mock_config_entry_setup(
|
||||
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert async_default_engine(hass) == "cloud"
|
||||
entity_engine = async_get_speech_to_text_engine(hass, "stt.new_test")
|
||||
assert entity_engine is not None
|
||||
assert entity_engine.name == "New test"
|
||||
provider_engine = async_get_speech_to_text_engine(hass, "test")
|
||||
assert provider_engine is not None
|
||||
assert provider_engine.name == "test"
|
||||
assert async_default_engine(hass) == "test"
|
||||
|
||||
|
||||
async def test_get_engine_legacy(
|
||||
|
|
Loading…
Add table
Reference in a new issue