Add migration logic to assist_pipeline (#115172)

This commit is contained in:
Paulus Schoutsen 2024-04-08 11:29:55 -04:00 committed by GitHub
parent cbaef096fa
commit f9a7e6bb9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 122 additions and 7 deletions

View file

@ -31,6 +31,8 @@ from .pipeline import (
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
async_migrate_engine,
async_run_migrations,
async_setup_pipeline_store, async_setup_pipeline_store,
async_update_pipeline, async_update_pipeline,
) )
@ -40,6 +42,7 @@ __all__ = (
"DOMAIN", "DOMAIN",
"async_create_default_pipeline", "async_create_default_pipeline",
"async_get_pipelines", "async_get_pipelines",
"async_migrate_engine",
"async_setup", "async_setup",
"async_pipeline_from_audio_stream", "async_pipeline_from_audio_stream",
"async_update_pipeline", "async_update_pipeline",
@ -72,6 +75,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.data[DATA_LAST_WAKE_UP] = {} hass.data[DATA_LAST_WAKE_UP] = {}
await async_setup_pipeline_store(hass) await async_setup_pipeline_store(hass)
await async_run_migrations(hass)
async_register_websocket_api(hass) async_register_websocket_api(hass)
return True return True

View file

@ -3,6 +3,7 @@
DOMAIN = "assist_pipeline" DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config" DATA_CONFIG = f"{DOMAIN}.config"
DATA_MIGRATIONS = f"{DOMAIN}_migrations"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds

View file

@ -13,7 +13,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from threading import Thread from threading import Thread
import time import time
from typing import TYPE_CHECKING, Any, Final, cast from typing import TYPE_CHECKING, Any, Final, Literal, cast
import wave import wave
import voluptuous as vol import voluptuous as vol
@ -56,6 +56,7 @@ from .const import (
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG, DATA_CONFIG,
DATA_LAST_WAKE_UP, DATA_LAST_WAKE_UP,
DATA_MIGRATIONS,
DOMAIN, DOMAIN,
WAKE_WORD_COOLDOWN, WAKE_WORD_COOLDOWN,
) )
@ -376,10 +377,6 @@ class Pipeline:
This function was added in HA Core 2023.10, previous versions will raise This function was added in HA Core 2023.10, previous versions will raise
if there are unexpected items in the serialized data. if there are unexpected items in the serialized data.
""" """
# Migrate to new value for conversation agent
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
return cls( return cls(
conversation_engine=data["conversation_engine"], conversation_engine=data["conversation_engine"],
conversation_language=data["conversation_language"], conversation_language=data["conversation_language"],
@ -1818,3 +1815,47 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
PIPELINE_FIELDS, PIPELINE_FIELDS,
).async_setup(hass) ).async_setup(hass)
return PipelineData(pipeline_store) return PipelineData(pipeline_store)
@callback
def async_migrate_engine(
hass: HomeAssistant,
engine_type: Literal["conversation", "stt", "tts", "wake_word"],
old_value: str,
new_value: str,
) -> None:
"""Register a migration of an engine used in pipelines."""
hass.data.setdefault(DATA_MIGRATIONS, {})[engine_type] = (old_value, new_value)
# Run migrations when config is already loaded
if DATA_CONFIG in hass.data:
hass.async_create_background_task(
async_run_migrations(hass), "assist_pipeline_migration", eager_start=True
)
async def async_run_migrations(hass: HomeAssistant) -> None:
"""Run pipeline migrations."""
if not (migrations := hass.data.get(DATA_MIGRATIONS)):
return
engine_attr = {
"conversation": "conversation_engine",
"stt": "stt_engine",
"tts": "tts_engine",
"wake_word": "wake_word_entity",
}
updates = []
for pipeline in async_get_pipelines(hass):
attr_updates = {}
for engine_type, (old_value, new_value) in migrations.items():
if getattr(pipeline, engine_attr[engine_type]) == old_value:
attr_updates[engine_attr[engine_type]] = new_value
if attr_updates:
updates.append((pipeline, attr_updates))
for pipeline, attr_updates in updates:
await async_update_pipeline(hass, pipeline, **attr_updates)

View file

@ -43,8 +43,9 @@ __all__ = [
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",
"async_set_agent", "async_set_agent",
"async_unset_agent",
"async_setup", "async_setup",
"async_unset_agent",
"ConversationEntity",
"ConversationInput", "ConversationInput",
"ConversationResult", "ConversationResult",
] ]
@ -188,6 +189,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass, entity_component, config.get(DOMAIN, {}).get("intents", {}) hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
) )
# Temporary migration. We can remove this in 2024.10
from homeassistant.components.assist_pipeline import ( # pylint: disable=import-outside-toplevel
async_migrate_engine,
)
async_migrate_engine(
hass, "conversation", OLD_HOME_ASSISTANT_AGENT, HOME_ASSISTANT_AGENT
)
async def handle_process(service: ServiceCall) -> ServiceResponse: async def handle_process(service: ServiceCall) -> ServiceResponse:
"""Parse text into commands.""" """Parse text into commands."""
text = service.data[ATTR_TEXT] text = service.data[ATTR_TEXT]
@ -227,3 +237,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_setup_conversation_http(hass) async_setup_conversation_http(hass)
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)

View file

@ -5,7 +5,6 @@
"dependencies": ["http", "intent"], "dependencies": ["http", "intent"],
"documentation": "https://www.home-assistant.io/integrations/conversation", "documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "system", "integration_type": "system",
"iot_class": "local_push",
"quality_scale": "internal", "quality_scale": "internal",
"requirements": ["hassil==1.6.1", "home-assistant-intents==2024.4.3"] "requirements": ["hassil==1.6.1", "home-assistant-intents==2024.4.3"]
} }

View file

@ -44,6 +44,7 @@ class Platform(StrEnum):
CALENDAR = "calendar" CALENDAR = "calendar"
CAMERA = "camera" CAMERA = "camera"
CLIMATE = "climate" CLIMATE = "climate"
CONVERSATION = "conversation"
COVER = "cover" COVER = "cover"
DATE = "date" DATE = "date"
DATETIME = "datetime" DATETIME = "datetime"

View file

@ -156,6 +156,8 @@ IGNORE_VIOLATIONS = {
("websocket_api", "lovelace"), ("websocket_api", "lovelace"),
("websocket_api", "shopping_list"), ("websocket_api", "shopping_list"),
"logbook", "logbook",
# Temporary needed for migration until 2024.10
("conversation", "assist_pipeline"),
} }

View file

@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
async_migrate_engine,
async_update_pipeline, async_update_pipeline,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -118,6 +119,12 @@ async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any] hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None: ) -> None:
"""Test loading stored pipelines on start.""" """Test loading stored pipelines on start."""
async_migrate_engine(
hass,
"conversation",
conversation.OLD_HOME_ASSISTANT_AGENT,
conversation.HOME_ASSISTANT_AGENT,
)
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY" id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
hass_storage[STORAGE_KEY] = { hass_storage[STORAGE_KEY] = {
"version": STORAGE_VERSION, "version": STORAGE_VERSION,
@ -614,3 +621,41 @@ async def test_update_pipeline(
"wake_word_entity": "wake_work.test_1", "wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1", "wake_word_id": "wake_word_id_1",
} }
async def test_migrate_after_load(
hass: HomeAssistant, init_supporting_components
) -> None:
"""Test migrating an engine after done loading."""
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
assert (
await async_create_default_pipeline(
hass,
stt_engine_id="bla",
tts_engine_id="bla",
pipeline_name="Bla pipeline",
)
is None
)
pipeline = await async_create_default_pipeline(
hass,
stt_engine_id="test",
tts_engine_id="test",
pipeline_name="Test pipeline",
)
assert pipeline is not None
async_migrate_engine(hass, "stt", "test", "stt.test")
async_migrate_engine(hass, "tts", "test", "tts.test")
await hass.async_block_till_done(wait_background_tasks=True)
pipeline_updated = async_get_pipeline(hass, pipeline.id)
assert pipeline_updated.stt_engine == "stt.test"
assert pipeline_updated.tts_engine == "tts.test"