Add migration logic to assist_pipeline (#115172)
This commit is contained in:
parent
cbaef096fa
commit
f9a7e6bb9f
8 changed files with 122 additions and 7 deletions
|
@ -31,6 +31,8 @@ from .pipeline import (
|
|||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_migrate_engine,
|
||||
async_run_migrations,
|
||||
async_setup_pipeline_store,
|
||||
async_update_pipeline,
|
||||
)
|
||||
|
@ -40,6 +42,7 @@ __all__ = (
|
|||
"DOMAIN",
|
||||
"async_create_default_pipeline",
|
||||
"async_get_pipelines",
|
||||
"async_migrate_engine",
|
||||
"async_setup",
|
||||
"async_pipeline_from_audio_stream",
|
||||
"async_update_pipeline",
|
||||
|
@ -72,6 +75,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass.data[DATA_LAST_WAKE_UP] = {}
|
||||
|
||||
await async_setup_pipeline_store(hass)
|
||||
await async_run_migrations(hass)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
DOMAIN = "assist_pipeline"
|
||||
|
||||
DATA_CONFIG = f"{DOMAIN}.config"
|
||||
DATA_MIGRATIONS = f"{DOMAIN}_migrations"
|
||||
|
||||
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
from queue import Empty, Queue
|
||||
from threading import Thread
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, cast
|
||||
import wave
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -56,6 +56,7 @@ from .const import (
|
|||
CONF_DEBUG_RECORDING_DIR,
|
||||
DATA_CONFIG,
|
||||
DATA_LAST_WAKE_UP,
|
||||
DATA_MIGRATIONS,
|
||||
DOMAIN,
|
||||
WAKE_WORD_COOLDOWN,
|
||||
)
|
||||
|
@ -376,10 +377,6 @@ class Pipeline:
|
|||
This function was added in HA Core 2023.10, previous versions will raise
|
||||
if there are unexpected items in the serialized data.
|
||||
"""
|
||||
# Migrate to new value for conversation agent
|
||||
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
|
||||
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
return cls(
|
||||
conversation_engine=data["conversation_engine"],
|
||||
conversation_language=data["conversation_language"],
|
||||
|
@ -1818,3 +1815,47 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
|
|||
PIPELINE_FIELDS,
|
||||
).async_setup(hass)
|
||||
return PipelineData(pipeline_store)
|
||||
|
||||
|
||||
@callback
|
||||
def async_migrate_engine(
|
||||
hass: HomeAssistant,
|
||||
engine_type: Literal["conversation", "stt", "tts", "wake_word"],
|
||||
old_value: str,
|
||||
new_value: str,
|
||||
) -> None:
|
||||
"""Register a migration of an engine used in pipelines."""
|
||||
hass.data.setdefault(DATA_MIGRATIONS, {})[engine_type] = (old_value, new_value)
|
||||
|
||||
# Run migrations when config is already loaded
|
||||
if DATA_CONFIG in hass.data:
|
||||
hass.async_create_background_task(
|
||||
async_run_migrations(hass), "assist_pipeline_migration", eager_start=True
|
||||
)
|
||||
|
||||
|
||||
async def async_run_migrations(hass: HomeAssistant) -> None:
|
||||
"""Run pipeline migrations."""
|
||||
if not (migrations := hass.data.get(DATA_MIGRATIONS)):
|
||||
return
|
||||
|
||||
engine_attr = {
|
||||
"conversation": "conversation_engine",
|
||||
"stt": "stt_engine",
|
||||
"tts": "tts_engine",
|
||||
"wake_word": "wake_word_entity",
|
||||
}
|
||||
|
||||
updates = []
|
||||
|
||||
for pipeline in async_get_pipelines(hass):
|
||||
attr_updates = {}
|
||||
for engine_type, (old_value, new_value) in migrations.items():
|
||||
if getattr(pipeline, engine_attr[engine_type]) == old_value:
|
||||
attr_updates[engine_attr[engine_type]] = new_value
|
||||
|
||||
if attr_updates:
|
||||
updates.append((pipeline, attr_updates))
|
||||
|
||||
for pipeline, attr_updates in updates:
|
||||
await async_update_pipeline(hass, pipeline, **attr_updates)
|
||||
|
|
|
@ -43,8 +43,9 @@ __all__ = [
|
|||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
"async_set_agent",
|
||||
"async_unset_agent",
|
||||
"async_setup",
|
||||
"async_unset_agent",
|
||||
"ConversationEntity",
|
||||
"ConversationInput",
|
||||
"ConversationResult",
|
||||
]
|
||||
|
@ -188,6 +189,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
|
||||
)
|
||||
|
||||
# Temporary migration. We can remove this in 2024.10
|
||||
from homeassistant.components.assist_pipeline import ( # pylint: disable=import-outside-toplevel
|
||||
async_migrate_engine,
|
||||
)
|
||||
|
||||
async_migrate_engine(
|
||||
hass, "conversation", OLD_HOME_ASSISTANT_AGENT, HOME_ASSISTANT_AGENT
|
||||
)
|
||||
|
||||
async def handle_process(service: ServiceCall) -> ServiceResponse:
|
||||
"""Parse text into commands."""
|
||||
text = service.data[ATTR_TEXT]
|
||||
|
@ -227,3 +237,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async_setup_conversation_http(hass)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
"dependencies": ["http", "intent"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/conversation",
|
||||
"integration_type": "system",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["hassil==1.6.1", "home-assistant-intents==2024.4.3"]
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ class Platform(StrEnum):
|
|||
CALENDAR = "calendar"
|
||||
CAMERA = "camera"
|
||||
CLIMATE = "climate"
|
||||
CONVERSATION = "conversation"
|
||||
COVER = "cover"
|
||||
DATE = "date"
|
||||
DATETIME = "datetime"
|
||||
|
|
|
@ -156,6 +156,8 @@ IGNORE_VIOLATIONS = {
|
|||
("websocket_api", "lovelace"),
|
||||
("websocket_api", "shopping_list"),
|
||||
"logbook",
|
||||
# Temporary needed for migration until 2024.10
|
||||
("conversation", "assist_pipeline"),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_migrate_engine,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -118,6 +119,12 @@ async def test_loading_pipelines_from_storage(
|
|||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test loading stored pipelines on start."""
|
||||
async_migrate_engine(
|
||||
hass,
|
||||
"conversation",
|
||||
conversation.OLD_HOME_ASSISTANT_AGENT,
|
||||
conversation.HOME_ASSISTANT_AGENT,
|
||||
)
|
||||
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": STORAGE_VERSION,
|
||||
|
@ -614,3 +621,41 @@ async def test_update_pipeline(
|
|||
"wake_word_entity": "wake_work.test_1",
|
||||
"wake_word_id": "wake_word_id_1",
|
||||
}
|
||||
|
||||
|
||||
async def test_migrate_after_load(
|
||||
hass: HomeAssistant, init_supporting_components
|
||||
) -> None:
|
||||
"""Test migrating an engine after done loading."""
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
assert (
|
||||
await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id="bla",
|
||||
tts_engine_id="bla",
|
||||
pipeline_name="Bla pipeline",
|
||||
)
|
||||
is None
|
||||
)
|
||||
pipeline = await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id="test",
|
||||
tts_engine_id="test",
|
||||
pipeline_name="Test pipeline",
|
||||
)
|
||||
assert pipeline is not None
|
||||
|
||||
async_migrate_engine(hass, "stt", "test", "stt.test")
|
||||
async_migrate_engine(hass, "tts", "test", "tts.test")
|
||||
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
pipeline_updated = async_get_pipeline(hass, pipeline.id)
|
||||
|
||||
assert pipeline_updated.stt_engine == "stt.test"
|
||||
assert pipeline_updated.tts_engine == "tts.test"
|
||||
|
|
Loading…
Add table
Reference in a new issue