Add migration logic to assist_pipeline (#115172)

This commit is contained in:
Paulus Schoutsen 2024-04-08 11:29:55 -04:00 committed by GitHub
parent cbaef096fa
commit f9a7e6bb9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 122 additions and 7 deletions

View file

@ -31,6 +31,8 @@ from .pipeline import (
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
async_migrate_engine,
async_run_migrations,
async_setup_pipeline_store,
async_update_pipeline,
)
@ -40,6 +42,7 @@ __all__ = (
"DOMAIN",
"async_create_default_pipeline",
"async_get_pipelines",
"async_migrate_engine",
"async_setup",
"async_pipeline_from_audio_stream",
"async_update_pipeline",
@ -72,6 +75,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.data[DATA_LAST_WAKE_UP] = {}
await async_setup_pipeline_store(hass)
await async_run_migrations(hass)
async_register_websocket_api(hass)
return True

View file

@ -3,6 +3,7 @@
DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config"
DATA_MIGRATIONS = f"{DOMAIN}_migrations"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds

View file

@ -13,7 +13,7 @@ from pathlib import Path
from queue import Empty, Queue
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Final, cast
from typing import TYPE_CHECKING, Any, Final, Literal, cast
import wave
import voluptuous as vol
@ -56,6 +56,7 @@ from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DATA_MIGRATIONS,
DOMAIN,
WAKE_WORD_COOLDOWN,
)
@ -376,10 +377,6 @@ class Pipeline:
This function was added in HA Core 2023.10, previous versions will raise
if there are unexpected items in the serialized data.
"""
# Migrate to new value for conversation agent
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
return cls(
conversation_engine=data["conversation_engine"],
conversation_language=data["conversation_language"],
@ -1818,3 +1815,47 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
PIPELINE_FIELDS,
).async_setup(hass)
return PipelineData(pipeline_store)
@callback
def async_migrate_engine(
hass: HomeAssistant,
engine_type: Literal["conversation", "stt", "tts", "wake_word"],
old_value: str,
new_value: str,
) -> None:
"""Register a migration of an engine used in pipelines."""
hass.data.setdefault(DATA_MIGRATIONS, {})[engine_type] = (old_value, new_value)
# Run migrations when config is already loaded
if DATA_CONFIG in hass.data:
hass.async_create_background_task(
async_run_migrations(hass), "assist_pipeline_migration", eager_start=True
)
async def async_run_migrations(hass: HomeAssistant) -> None:
"""Run pipeline migrations."""
if not (migrations := hass.data.get(DATA_MIGRATIONS)):
return
engine_attr = {
"conversation": "conversation_engine",
"stt": "stt_engine",
"tts": "tts_engine",
"wake_word": "wake_word_entity",
}
updates = []
for pipeline in async_get_pipelines(hass):
attr_updates = {}
for engine_type, (old_value, new_value) in migrations.items():
if getattr(pipeline, engine_attr[engine_type]) == old_value:
attr_updates[engine_attr[engine_type]] = new_value
if attr_updates:
updates.append((pipeline, attr_updates))
for pipeline, attr_updates in updates:
await async_update_pipeline(hass, pipeline, **attr_updates)

View file

@ -43,8 +43,9 @@ __all__ = [
"async_converse",
"async_get_agent_info",
"async_set_agent",
"async_unset_agent",
"async_setup",
"async_unset_agent",
"ConversationEntity",
"ConversationInput",
"ConversationResult",
]
@ -188,6 +189,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
)
# Temporary migration. We can remove this in 2024.10
from homeassistant.components.assist_pipeline import ( # pylint: disable=import-outside-toplevel
async_migrate_engine,
)
async_migrate_engine(
hass, "conversation", OLD_HOME_ASSISTANT_AGENT, HOME_ASSISTANT_AGENT
)
async def handle_process(service: ServiceCall) -> ServiceResponse:
"""Parse text into commands."""
text = service.data[ATTR_TEXT]
@ -227,3 +237,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_setup_conversation_http(hass)
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)

View file

@ -5,7 +5,6 @@
"dependencies": ["http", "intent"],
"documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "system",
"iot_class": "local_push",
"quality_scale": "internal",
"requirements": ["hassil==1.6.1", "home-assistant-intents==2024.4.3"]
}

View file

@ -44,6 +44,7 @@ class Platform(StrEnum):
CALENDAR = "calendar"
CAMERA = "camera"
CLIMATE = "climate"
CONVERSATION = "conversation"
COVER = "cover"
DATE = "date"
DATETIME = "datetime"

View file

@ -156,6 +156,8 @@ IGNORE_VIOLATIONS = {
("websocket_api", "lovelace"),
("websocket_api", "shopping_list"),
"logbook",
# Temporary needed for migration until 2024.10
("conversation", "assist_pipeline"),
}

View file

@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
async_migrate_engine,
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
@ -118,6 +119,12 @@ async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test loading stored pipelines on start."""
async_migrate_engine(
hass,
"conversation",
conversation.OLD_HOME_ASSISTANT_AGENT,
conversation.HOME_ASSISTANT_AGENT,
)
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
hass_storage[STORAGE_KEY] = {
"version": STORAGE_VERSION,
@ -614,3 +621,41 @@ async def test_update_pipeline(
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
}
async def test_migrate_after_load(
hass: HomeAssistant, init_supporting_components
) -> None:
"""Test migrating an engine after done loading."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
assert (
await async_create_default_pipeline(
hass,
stt_engine_id="bla",
tts_engine_id="bla",
pipeline_name="Bla pipeline",
)
is None
)
pipeline = await async_create_default_pipeline(
hass,
stt_engine_id="test",
tts_engine_id="test",
pipeline_name="Test pipeline",
)
assert pipeline is not None
async_migrate_engine(hass, "stt", "test", "stt.test")
async_migrate_engine(hass, "tts", "test", "tts.test")
await hass.async_block_till_done(wait_background_tasks=True)
pipeline_updated = async_get_pipeline(hass, pipeline.id)
assert pipeline_updated.stt_engine == "stt.test"
assert pipeline_updated.tts_engine == "tts.test"