From d2e4f5f36e8e46cd0a8516ef92103de7ee2c2a59 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 1 Apr 2024 21:34:25 -0400 Subject: [PATCH] Add conversation entity (#114518) * Default agent as entity * Migrate constant to point at new value * Fix tests * Fix more tests * Move assist pipeline back to cloud after dependenceis --- .../components/assist_pipeline/pipeline.py | 4 + homeassistant/components/cloud/http_api.py | 5 +- homeassistant/components/cloud/stt.py | 16 +- homeassistant/components/cloud/tts.py | 17 +- .../components/conversation/__init__.py | 53 ++++-- .../components/conversation/agent_manager.py | 103 +++++------ .../components/conversation/const.py | 4 +- .../components/conversation/default_agent.py | 44 +++-- .../components/conversation/entity.py | 57 ++++++ homeassistant/components/conversation/http.py | 46 ++++- .../components/conversation/manifest.json | 2 +- .../components/conversation/models.py | 8 + .../components/conversation/trigger.py | 7 +- .../assist_pipeline/snapshots/test_init.ambr | 8 +- .../snapshots/test_websocket.ambr | 16 +- .../assist_pipeline/test_pipeline.py | 25 +-- .../assist_pipeline/test_websocket.py | 8 +- .../components/cloud/test_assist_pipeline.py | 2 + tests/components/cloud/test_http_api.py | 2 + tests/components/conversation/conftest.py | 13 +- .../conversation/snapshots/test_init.ambr | 166 +++++++++++++++++- .../conversation/test_default_agent.py | 12 +- tests/components/conversation/test_entity.py | 47 +++++ tests/components/conversation/test_init.py | 31 ++-- tests/components/conversation/test_trigger.py | 4 +- .../google_assistant_sdk/test_init.py | 5 +- .../conftest.py | 11 +- .../test_init.py | 4 +- tests/components/mobile_app/test_webhook.py | 2 +- tests/components/ollama/conftest.py | 6 + tests/components/ollama/test_init.py | 6 +- .../openai_conversation/conftest.py | 7 + .../openai_conversation/test_init.py | 2 +- 33 files changed, 566 insertions(+), 177 deletions(-) create mode 100644 homeassistant/components/conversation/entity.py create mode 100644 tests/components/conversation/test_entity.py diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 01a12b3635b..33e1b8c2f76 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -376,6 +376,10 @@ class Pipeline: This function was added in HA Core 2023.10, previous versions will raise if there are unexpected items in the serialized data. """ + # Migrate to new value for conversation agent + if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT: + data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT + return cls( conversation_engine=data["conversation_engine"], conversation_language=data["conversation_language"], diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 8ca55876b28..b577e9de0d4 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -223,7 +223,10 @@ class CloudLoginView(HomeAssistantView): cloud: Cloud[CloudClient] = hass.data[DOMAIN] await cloud.login(data["email"], data["password"]) - new_cloud_pipeline_id = await async_create_cloud_pipeline(hass) + if "assist_pipeline" in hass.config.components: + new_cloud_pipeline_id = await async_create_cloud_pipeline(hass) + else: + new_cloud_pipeline_id = None return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id}) diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index d718cc5201e..c68e9f245ee 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -24,6 +24,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.setup import async_when_setup from .assist_pipeline import async_migrate_cloud_pipeline_engine from .client import CloudClient @@ -86,9 +87,18 @@ class CloudProviderEntity(SpeechToTextEntity): async def async_added_to_hass(self) -> None: """Run when entity is about to be added to hass.""" - await async_migrate_cloud_pipeline_engine( - self.hass, platform=Platform.STT, engine_id=self.entity_id - ) + + async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None: + """When assist_pipeline is set up.""" + assert self.platform.config_entry + self.platform.config_entry.async_create_task( + hass, + async_migrate_cloud_pipeline_engine( + self.hass, platform=Platform.STT, engine_id=self.entity_id + ), + ) + + async_when_setup(self.hass, "assist_pipeline", pipeline_setup) async def async_process_audio_stream( self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 42e4b94a189..53cec74d133 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -27,6 +27,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.setup import async_when_setup from .assist_pipeline import async_migrate_cloud_pipeline_engine from .client import CloudClient @@ -156,9 +157,19 @@ class CloudTTSEntity(TextToSpeechEntity): async def async_added_to_hass(self) -> None: """Handle entity which will be added.""" await super().async_added_to_hass() - await async_migrate_cloud_pipeline_engine( - self.hass, platform=Platform.TTS, engine_id=self.entity_id - ) + + async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None: + """When assist_pipeline is set up.""" + assert self.platform.config_entry + self.platform.config_entry.async_create_task( + hass, + async_migrate_cloud_pipeline_engine( + self.hass, platform=Platform.TTS, engine_id=self.entity_id + ), + ) + + async_when_setup(self.hass, "assist_pipeline", pipeline_setup) + self.async_on_remove( self.cloud.client.prefs.async_listen_updates(self._sync_prefs) ) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index a0717ddaa58..63e0e9bff59 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Iterable import logging import re from typing import Literal @@ -20,6 +19,7 @@ from homeassistant.core import ( ) from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, intent +from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass @@ -27,15 +27,19 @@ from .agent_manager import ( AgentInfo, agent_id_validator, async_converse, + async_get_agent, get_agent_manager, ) -from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT +from .const import HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT +from .default_agent import async_get_default_agent, async_setup_default_agent +from .entity import ConversationEntity from .http import async_setup as async_setup_conversation_http from .models import AbstractConversationAgent, ConversationInput, ConversationResult __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", + "OLD_HOME_ASSISTANT_AGENT", "async_converse", "async_get_agent_info", "async_set_agent", @@ -122,16 +126,26 @@ async def async_get_conversation_languages( all conversation agents. """ agent_manager = get_agent_manager(hass) + entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN] languages: set[str] = set() + agents: list[ConversationEntity | AbstractConversationAgent] + + if agent_id: + agent = async_get_agent(hass, agent_id) + + if agent is None: + raise ValueError(f"Agent {agent_id} not found") + + agents = [agent] - agent_ids: Iterable[str] - if agent_id is None: - agent_ids = iter(info.id for info in agent_manager.async_get_agent_info()) else: - agent_ids = (agent_id,) + agents = list(entity_component.entities) + for info in agent_manager.async_get_agent_info(): + agent = agent_manager.async_get_agent(info.id) + assert agent is not None + agents.append(agent) - for _agent_id in agent_ids: - agent = await agent_manager.async_get_agent(_agent_id) + for agent in agents: if agent.supported_languages == MATCH_ALL: return MATCH_ALL for language_tag in agent.supported_languages: @@ -146,10 +160,18 @@ def async_get_agent_info( agent_id: str | None = None, ) -> AgentInfo | None: """Get information on the agent or None if not found.""" - manager = get_agent_manager(hass) + agent = async_get_agent(hass, agent_id) - if agent_id is None: - agent_id = manager.default_agent + if agent is None: + return None + + if isinstance(agent, ConversationEntity): + name = agent.name + if not isinstance(name, str): + name = agent.entity_id + return AgentInfo(id=agent.entity_id, name=name) + + manager = get_agent_manager(hass) for agent_info in manager.async_get_agent_info(): if agent_info.id == agent_id: @@ -160,10 +182,11 @@ def async_get_agent_info( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" - agent_manager = get_agent_manager(hass) + entity_component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass) - if config_intents := config.get(DOMAIN, {}).get("intents"): - hass.data[DATA_CONFIG] = config_intents + await async_setup_default_agent( + hass, entity_component, config.get(DOMAIN, {}).get("intents", {}) + ) async def handle_process(service: ServiceCall) -> ServiceResponse: """Parse text into commands.""" @@ -188,7 +211,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def handle_reload(service: ServiceCall) -> None: """Reload intents.""" - agent = await agent_manager.async_get_agent() + agent = async_get_default_agent(hass) await agent.async_reload(language=service.data.get(ATTR_LANGUAGE)) hass.services.async_register( diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index f34ecfaecc9..838539b4992 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -2,8 +2,6 @@ from __future__ import annotations -import asyncio -from dataclasses import dataclass import logging from typing import Any @@ -11,10 +9,17 @@ import voluptuous as vol from homeassistant.core import Context, HomeAssistant, async_get_hass, callback from homeassistant.helpers import config_validation as cv, singleton +from homeassistant.helpers.entity_component import EntityComponent -from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT -from .default_agent import DefaultAgent, async_setup as async_setup_default_agent -from .models import AbstractConversationAgent, ConversationInput, ConversationResult +from .const import DOMAIN, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT +from .default_agent import async_get_default_agent +from .entity import ConversationEntity +from .models import ( + AbstractConversationAgent, + AgentInfo, + ConversationInput, + ConversationResult, +) _LOGGER = logging.getLogger(__name__) @@ -23,20 +28,37 @@ _LOGGER = logging.getLogger(__name__) @callback def get_agent_manager(hass: HomeAssistant) -> AgentManager: """Get the active agent.""" - manager = AgentManager(hass) - manager.async_setup() - return manager + return AgentManager(hass) def agent_id_validator(value: Any) -> str: """Validate agent ID.""" hass = async_get_hass() - manager = get_agent_manager(hass) - if not manager.async_is_valid_agent_id(cv.string(value)): + if async_get_agent(hass, cv.string(value)) is None: raise vol.Invalid("invalid agent ID") return value +@callback +def async_get_agent( + hass: HomeAssistant, agent_id: str | None = None +) -> AbstractConversationAgent | ConversationEntity | None: + """Get specified agent.""" + if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT): + return async_get_default_agent(hass) + + if "." in agent_id: + entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN] + return entity_component.get_entity(agent_id) + + manager = get_agent_manager(hass) + + if not manager.async_is_valid_agent_id(agent_id): + return None + + return manager.async_get_agent(agent_id) + + async def async_converse( hass: HomeAssistant, text: str, @@ -47,13 +69,22 @@ async def async_converse( device_id: str | None = None, ) -> ConversationResult: """Process text and get intent.""" - agent = await get_agent_manager(hass).async_get_agent(agent_id) + agent = async_get_agent(hass, agent_id) + + if agent is None: + raise ValueError(f"Agent {agent_id} not found") + + if isinstance(agent, ConversationEntity): + agent.async_set_context(context) + method = agent.internal_async_process + else: + method = agent.async_process if language is None: language = hass.config.language _LOGGER.debug("Processing in %s: %s", language, text) - result = await agent.async_process( + result = await method( ConversationInput( text=text, context=context, @@ -65,52 +96,17 @@ async def async_converse( return result -@dataclass(frozen=True) -class AgentInfo: - """Container for conversation agent info.""" - - id: str - name: str - - class AgentManager: """Class to manage conversation agents.""" - default_agent: str = HOME_ASSISTANT_AGENT - _builtin_agent: AbstractConversationAgent | None = None - def __init__(self, hass: HomeAssistant) -> None: """Initialize the conversation agents.""" self.hass = hass self._agents: dict[str, AbstractConversationAgent] = {} - self._builtin_agent_init_lock = asyncio.Lock() - def async_setup(self) -> None: - """Set up the conversation agents.""" - async_setup_default_agent(self.hass) - - async def async_get_agent( - self, agent_id: str | None = None - ) -> AbstractConversationAgent: + @callback + def async_get_agent(self, agent_id: str) -> AbstractConversationAgent | None: """Get the agent.""" - if agent_id is None: - agent_id = self.default_agent - - if agent_id == HOME_ASSISTANT_AGENT: - if self._builtin_agent is not None: - return self._builtin_agent - - async with self._builtin_agent_init_lock: - if self._builtin_agent is not None: - return self._builtin_agent - - self._builtin_agent = DefaultAgent(self.hass) - await self._builtin_agent.async_initialize( - self.hass.data.get(DATA_CONFIG) - ) - - return self._builtin_agent - if agent_id not in self._agents: raise ValueError(f"Agent {agent_id} not found") @@ -119,12 +115,7 @@ class AgentManager: @callback def async_get_agent_info(self) -> list[AgentInfo]: """List all agents.""" - agents: list[AgentInfo] = [ - AgentInfo( - id=HOME_ASSISTANT_AGENT, - name="Home Assistant", - ) - ] + agents: list[AgentInfo] = [] for agent_id, agent in self._agents.items(): config_entry = self.hass.config_entries.async_get_entry(agent_id) @@ -148,7 +139,7 @@ class AgentManager: @callback def async_is_valid_agent_id(self, agent_id: str) -> bool: """Check if the agent id is valid.""" - return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT + return agent_id in self._agents @callback def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: diff --git a/homeassistant/components/conversation/const.py b/homeassistant/components/conversation/const.py index 5cb5ca3bdea..d20b6d96aa2 100644 --- a/homeassistant/components/conversation/const.py +++ b/homeassistant/components/conversation/const.py @@ -2,5 +2,5 @@ DOMAIN = "conversation" DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"} -HOME_ASSISTANT_AGENT = "homeassistant" -DATA_CONFIG = "conversation_config" +HOME_ASSISTANT_AGENT = "conversation.home_assistant" +OLD_HOME_ASSISTANT_AGENT = "homeassistant" diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 5a8d7b64eec..32ab7924916 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -24,7 +24,7 @@ from hassil.util import merge_dict from home_assistant_intents import ErrorKey, get_intents, get_languages import yaml -from homeassistant import core, setup +from homeassistant import core from homeassistant.components.homeassistant.exposed_entities import ( async_listen_entity_updates, async_should_expose, @@ -40,6 +40,7 @@ from homeassistant.helpers import ( template, translation, ) +from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import ( EventStateChangedData, async_track_state_added_domain, @@ -47,7 +48,8 @@ from homeassistant.helpers.event import ( from homeassistant.util.json import JsonObjectType, json_loads_object from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN -from .models import AbstractConversationAgent, ConversationInput, ConversationResult +from .entity import ConversationEntity +from .models import ConversationInput, ConversationResult _LOGGER = logging.getLogger(__name__) _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" @@ -60,6 +62,14 @@ TRIGGER_CALLBACK_TYPE = Callable[ METADATA_CUSTOM_SENTENCE = "hass_custom_sentence" METADATA_CUSTOM_FILE = "hass_custom_file" +DATA_DEFAULT_ENTITY = "conversation_default_entity" + + +@core.callback +def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent: + """Get the default agent.""" + return hass.data[DATA_DEFAULT_ENTITY] + def json_load(fp: IO[str]) -> JsonObjectType: """Wrap json_loads for get_intents.""" @@ -109,9 +119,16 @@ def _get_language_variations(language: str) -> Iterable[str]: yield lang -@core.callback -def async_setup(hass: core.HomeAssistant) -> None: +async def async_setup_default_agent( + hass: core.HomeAssistant, + entity_component: EntityComponent[ConversationEntity], + config_intents: dict[str, Any], +) -> None: """Set up entity registry listener for the default agent.""" + entity = DefaultAgent(hass, config_intents) + await entity_component.async_add_entities([entity]) + hass.data[DATA_DEFAULT_ENTITY] = entity + entity_registry = er.async_get(hass) for entity_id in entity_registry.entities: async_should_expose(hass, DOMAIN, entity_id) @@ -131,17 +148,21 @@ def async_setup(hass: core.HomeAssistant) -> None: start.async_at_started(hass, async_hass_started) -class DefaultAgent(AbstractConversationAgent): +class DefaultAgent(ConversationEntity): """Default agent for conversation agent.""" - def __init__(self, hass: core.HomeAssistant) -> None: + _attr_name = "Home Assistant" + + def __init__( + self, hass: core.HomeAssistant, config_intents: dict[str, Any] + ) -> None: """Initialize the default agent.""" self.hass = hass self._lang_intents: dict[str, LanguageIntents] = {} self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # intent -> [sentences] - self._config_intents: dict[str, Any] = {} + self._config_intents: dict[str, Any] = config_intents self._slot_lists: dict[str, SlotList] | None = None # Sentences that will trigger a callback (skipping intent recognition) @@ -154,15 +175,6 @@ class DefaultAgent(AbstractConversationAgent): """Return a list of supported languages.""" return get_languages() - async def async_initialize(self, config_intents: dict[str, Any] | None) -> None: - """Initialize the default agent.""" - if "intent" not in self.hass.config.components: - await setup.async_setup_component(self.hass, "intent", {}) - - # Intents from config may only contains sentences for HA config's language - if config_intents: - self._config_intents = config_intents - @core.callback def _filter_entity_registry_changes(self, event_data: dict[str, Any]) -> bool: """Filter entity registry changed events.""" diff --git a/homeassistant/components/conversation/entity.py b/homeassistant/components/conversation/entity.py new file mode 100644 index 00000000000..12dbea41344 --- /dev/null +++ b/homeassistant/components/conversation/entity.py @@ -0,0 +1,57 @@ +"""Entity for conversation integration.""" + +from abc import abstractmethod +from typing import Literal, final + +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.util import dt as dt_util + +from .models import ConversationInput, ConversationResult + + +class ConversationEntity(RestoreEntity): + """Entity that supports conversations.""" + + _attr_should_poll = False + __last_activity: str | None = None + + @property + @final + def state(self) -> str | None: + """Return the state of the entity.""" + if self.__last_activity is None: + return None + return self.__last_activity + + async def async_internal_added_to_hass(self) -> None: + """Call when the entity is added to hass.""" + await super().async_internal_added_to_hass() + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_activity = state.state + + @final + async def internal_async_process( + self, user_input: ConversationInput + ) -> ConversationResult: + """Process a sentence.""" + self.__last_activity = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self.async_process(user_input) + + @property + @abstractmethod + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + + @abstractmethod + async def async_process(self, user_input: ConversationInput) -> ConversationResult: + """Process a sentence.""" + + async def async_prepare(self, language: str | None = None) -> None: + """Load intents for a language.""" diff --git a/homeassistant/components/conversation/http.py b/homeassistant/components/conversation/http.py index fb67d686b23..beda7ba1550 100644 --- a/homeassistant/components/conversation/http.py +++ b/homeassistant/components/conversation/http.py @@ -19,16 +19,24 @@ from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.const import MATCH_ALL from homeassistant.core import HomeAssistant, State, callback from homeassistant.helpers import config_validation as cv, intent +from homeassistant.helpers.entity_component import EntityComponent from homeassistant.util import language as language_util -from .agent_manager import agent_id_validator, async_converse, get_agent_manager -from .const import HOME_ASSISTANT_AGENT +from .agent_manager import ( + agent_id_validator, + async_converse, + async_get_agent, + get_agent_manager, +) +from .const import DOMAIN from .default_agent import ( METADATA_CUSTOM_FILE, METADATA_CUSTOM_SENTENCE, DefaultAgent, SentenceTriggerResult, + async_get_default_agent, ) +from .entity import ConversationEntity from .models import ConversationInput @@ -83,8 +91,14 @@ async def websocket_prepare( msg: dict[str, Any], ) -> None: """Reload intents.""" - manager = get_agent_manager(hass) - agent = await manager.async_get_agent(msg.get("agent_id")) + agent = async_get_agent(hass, msg.get("agent_id")) + + if agent is None: + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_FOUND, "Agent not found" + ) + return + await agent.async_prepare(msg.get("language")) connection.send_result(msg["id"]) @@ -101,14 +115,32 @@ async def websocket_list_agents( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """List conversation agents and, optionally, if they support a given language.""" - manager = get_agent_manager(hass) + entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN] country = msg.get("country") language = msg.get("language") agents = [] + for entity in entity_component.entities: + supported_languages = entity.supported_languages + if language and supported_languages != MATCH_ALL: + supported_languages = language_util.matches( + language, supported_languages, country + ) + + agents.append( + { + "id": entity.entity_id, + "name": entity.name or entity.entity_id, + "supported_languages": supported_languages, + } + ) + + manager = get_agent_manager(hass) + for agent_info in manager.async_get_agent_info(): - agent = await manager.async_get_agent(agent_info.id) + agent = manager.async_get_agent(agent_info.id) + assert agent is not None supported_languages = agent.supported_languages if language and supported_languages != MATCH_ALL: @@ -139,7 +171,7 @@ async def websocket_hass_agent_debug( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Return intents that would be matched by the default agent for a list of sentences.""" - agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) + agent = async_get_default_agent(hass) assert isinstance(agent, DefaultAgent) results = [ await agent.async_recognize( diff --git a/homeassistant/components/conversation/manifest.json b/homeassistant/components/conversation/manifest.json index 7f463483bf9..07fc86313ba 100644 --- a/homeassistant/components/conversation/manifest.json +++ b/homeassistant/components/conversation/manifest.json @@ -2,7 +2,7 @@ "domain": "conversation", "name": "Conversation", "codeowners": ["@home-assistant/core", "@synesthesiam"], - "dependencies": ["http"], + "dependencies": ["http", "intent"], "documentation": "https://www.home-assistant.io/integrations/conversation", "integration_type": "system", "iot_class": "local_push", diff --git a/homeassistant/components/conversation/models.py b/homeassistant/components/conversation/models.py index 22b3437907c..3fd24152698 100644 --- a/homeassistant/components/conversation/models.py +++ b/homeassistant/components/conversation/models.py @@ -10,6 +10,14 @@ from homeassistant.core import Context from homeassistant.helpers import intent +@dataclass(frozen=True) +class AgentInfo: + """Container for conversation agent info.""" + + id: str + name: str + + @dataclass(slots=True) class ConversationInput: """User input to be processed.""" diff --git a/homeassistant/components/conversation/trigger.py b/homeassistant/components/conversation/trigger.py index 05fea054bca..0a4cbfcb7e5 100644 --- a/homeassistant/components/conversation/trigger.py +++ b/homeassistant/components/conversation/trigger.py @@ -14,9 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import UNDEFINED, ConfigType -from .agent_manager import get_agent_manager -from .const import DOMAIN, HOME_ASSISTANT_AGENT -from .default_agent import DefaultAgent +from .const import DOMAIN +from .default_agent import DefaultAgent, async_get_default_agent def has_no_punctuation(value: list[str]) -> list[str]: @@ -111,7 +110,7 @@ async def async_attach_trigger( # two trigger copies for who will provide a response. return None - default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT) + default_agent = async_get_default_agent(hass) assert isinstance(default_agent, DefaultAgent) return default_agent.register_trigger(sentences, call_action) diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index bbd0c9d333a..8124ed4ab85 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -34,7 +34,7 @@ 'data': dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }), @@ -123,7 +123,7 @@ 'data': dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en-US', }), @@ -212,7 +212,7 @@ 'data': dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en-US', }), @@ -325,7 +325,7 @@ 'data': dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 10a76bc9344..f952e3b7286 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -33,7 +33,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }) @@ -114,7 +114,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }) @@ -207,7 +207,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }) @@ -409,7 +409,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', 'language': 'en', }) @@ -615,7 +615,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'Are the lights on?', 'language': 'en', }) @@ -637,7 +637,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'Are the lights on?', 'language': 'en', }) @@ -665,7 +665,7 @@ dict({ 'conversation_id': None, 'device_id': None, - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'never mind', 'language': 'en', }) @@ -799,7 +799,7 @@ dict({ 'conversation_id': 'mock-conversation-id', 'device_id': 'mock-device-id', - 'engine': 'homeassistant', + 'engine': 'conversation.home_assistant', 'intent_input': 'Are the lights on?', 'language': 'en', }) diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 3bfe6605839..3588bba6416 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -6,6 +6,7 @@ from unittest.mock import ANY, patch import pytest +from homeassistant.components import conversation from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, @@ -117,6 +118,7 @@ async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: """Test loading stored pipelines on start.""" + id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY" hass_storage[STORAGE_KEY] = { "version": STORAGE_VERSION, "minor_version": STORAGE_VERSION_MINOR, @@ -124,9 +126,9 @@ async def test_loading_pipelines_from_storage( "data": { "items": [ { - "conversation_engine": "conversation_engine_1", + "conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT, "conversation_language": "language_1", - "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "id": id_1, "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", @@ -166,7 +168,7 @@ async def test_loading_pipelines_from_storage( "wake_word_id": "wakeword_id_3", }, ], - "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "preferred_item": id_1, }, } @@ -175,7 +177,8 @@ async def test_loading_pipelines_from_storage( pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store assert len(store.data) == 3 - assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" + assert store.async_get_preferred_item() == id_1 + assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT async def test_migrate_pipeline_store( @@ -262,7 +265,7 @@ async def test_create_default_pipeline( tts_engine_id="test", pipeline_name="Test pipeline", ) == Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", @@ -304,7 +307,7 @@ async def test_get_pipelines(hass: HomeAssistant) -> None: pipelines = async_get_pipelines(hass) assert list(pipelines) == [ Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", @@ -351,7 +354,7 @@ async def test_default_pipeline_no_stt_tts( # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, @@ -414,7 +417,7 @@ async def test_default_pipeline( # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language=conv_language, id=pipeline.id, language=pipeline_language, @@ -445,7 +448,7 @@ async def test_default_pipeline_unsupported_stt_language( # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", @@ -476,7 +479,7 @@ async def test_default_pipeline_unsupported_tts_language( # Check the default pipeline pipeline = async_get_pipeline(hass, None) assert pipeline == Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language="en", id=pipeline.id, language="en", @@ -502,7 +505,7 @@ async def test_update_pipeline( pipelines = list(pipelines) assert pipelines == [ Pipeline( - conversation_engine="homeassistant", + conversation_engine="conversation.home_assistant", conversation_language="en", id=ANY, language="en", diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 0883046f3a1..e08dd9685ea 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -1166,7 +1166,7 @@ async def test_get_pipeline( msg = await client.receive_json() assert msg["success"] assert msg["result"] == { - "conversation_engine": "homeassistant", + "conversation_engine": "conversation.home_assistant", "conversation_language": "en", "id": ANY, "language": "en", @@ -1250,7 +1250,7 @@ async def test_list_pipelines( assert msg["result"] == { "pipelines": [ { - "conversation_engine": "homeassistant", + "conversation_engine": "conversation.home_assistant", "conversation_language": "en", "id": ANY, "language": "en", @@ -2012,7 +2012,7 @@ async def test_wake_word_cooldown_different_entities( await client_pipeline.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", - "conversation_engine": "homeassistant", + "conversation_engine": "conversation.home_assistant", "conversation_language": "en-US", "language": "en", "name": "pipeline_with_wake_word_1", @@ -2032,7 +2032,7 @@ async def test_wake_word_cooldown_different_entities( await client_pipeline.send_json_auto_id( { "type": "assist_pipeline/pipeline/create", - "conversation_engine": "homeassistant", + "conversation_engine": "conversation.home_assistant", "conversation_language": "en-US", "language": "en", "name": "pipeline_with_wake_word_2", diff --git a/tests/components/cloud/test_assist_pipeline.py b/tests/components/cloud/test_assist_pipeline.py index 5c2fc074898..de30212c040 100644 --- a/tests/components/cloud/test_assist_pipeline.py +++ b/tests/components/cloud/test_assist_pipeline.py @@ -7,10 +7,12 @@ from homeassistant.components.cloud.assist_pipeline import ( ) from homeassistant.const import Platform from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None: """Test migrate pipeline with invalid platform.""" + await async_setup_component(hass, "assist_pipeline", {}) with pytest.raises(ValueError): await async_migrate_cloud_pipeline_engine( hass, Platform.BINARY_SENSOR, "test-engine-id" diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 0dad7cfa882..5ee9af88681 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -231,6 +231,7 @@ async def test_login_view_create_pipeline( } assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "assist_pipeline", {}) assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) await hass.async_block_till_done() @@ -270,6 +271,7 @@ async def test_login_view_create_pipeline_fail( } assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "assist_pipeline", {}) assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) await hass.async_block_till_done() diff --git a/tests/components/conversation/conftest.py b/tests/components/conversation/conftest.py index d6c2d9e2e5e..4801e506460 100644 --- a/tests/components/conversation/conftest.py +++ b/tests/components/conversation/conftest.py @@ -7,6 +7,8 @@ import pytest from homeassistant.components import conversation from homeassistant.components.shopping_list import intent as sl_intent from homeassistant.const import MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component from . import MockAgent @@ -14,7 +16,7 @@ from tests.common import MockConfigEntry @pytest.fixture -def mock_agent_support_all(hass): +def mock_agent_support_all(hass: HomeAssistant): """Mock agent that supports all languages.""" entry = MockConfigEntry(entry_id="mock-entry-support-all") entry.add_to_hass(hass) @@ -34,7 +36,7 @@ def mock_shopping_list_io(): @pytest.fixture -async def sl_setup(hass): +async def sl_setup(hass: HomeAssistant): """Set up the shopping list.""" entry = MockConfigEntry(domain="shopping_list") @@ -43,3 +45,10 @@ async def sl_setup(hass): assert await hass.config_entries.async_setup(entry.entry_id) await sl_intent.async_setup_intents(hass) + + +@pytest.fixture +async def init_components(hass: HomeAssistant): + """Initialize relevant components with empty configs.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "conversation", {}) diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr index 6af9d197e01..d514d145477 100644 --- a/tests/components/conversation/snapshots/test_init.ambr +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -101,7 +101,7 @@ # --- # name: test_get_agent_info dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', }) # --- @@ -113,7 +113,7 @@ # --- # name: test_get_agent_info.2 dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', }) # --- @@ -127,7 +127,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ 'af', @@ -207,7 +207,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ ]), @@ -231,7 +231,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ 'en', @@ -255,7 +255,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ 'en', @@ -279,7 +279,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ 'de', @@ -304,7 +304,7 @@ dict({ 'agents': list([ dict({ - 'id': 'homeassistant', + 'id': 'conversation.home_assistant', 'name': 'Home Assistant', 'supported_languages': list([ 'de-CH', @@ -415,6 +415,36 @@ }), }) # --- +# name: test_http_processing_intent[conversation.home_assistant] + dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + dict({ + 'id': 'light.kitchen', + 'name': 'kitchen', + 'type': 'entity', + }), + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Turned on the light', + }), + }), + }), + }) +# --- # name: test_http_processing_intent[homeassistant] dict({ 'conversation_id': None, @@ -1035,6 +1065,36 @@ }), }) # --- +# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant] + dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + dict({ + 'id': 'light.kitchen', + 'name': 'kitchen', + 'type': , + }), + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Turned on the light', + }), + }), + }), + }) +# --- # name: test_turn_on_intent[None-turn kitchen on-homeassistant] dict({ 'conversation_id': None, @@ -1095,6 +1155,36 @@ }), }) # --- +# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant] + dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + dict({ + 'id': 'light.kitchen', + 'name': 'kitchen', + 'type': , + }), + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Turned on the light', + }), + }), + }), + }) +# --- # name: test_turn_on_intent[None-turn on kitchen-homeassistant] dict({ 'conversation_id': None, @@ -1155,6 +1245,36 @@ }), }) # --- +# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant] + dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + dict({ + 'id': 'light.kitchen', + 'name': 'kitchen', + 'type': , + }), + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Turned on the light', + }), + }), + }), + }) +# --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] dict({ 'conversation_id': None, @@ -1215,6 +1335,36 @@ }), }) # --- +# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant] + dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + dict({ + 'id': 'light.kitchen', + 'name': 'kitchen', + 'type': , + }), + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Turned on the light', + }), + }), + }), + }) +# --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] dict({ 'conversation_id': None, diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index c600c71711e..474198cb8a3 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -7,7 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult import pytest from homeassistant.components import conversation -from homeassistant.components.conversation import agent_manager, default_agent +from homeassistant.components.conversation import default_agent from homeassistant.components.homeassistant.exposed_entities import ( async_get_assistant_settings, ) @@ -152,9 +152,7 @@ async def test_conversation_agent( init_components, ) -> None: """Test DefaultAgent.""" - agent = await agent_manager.get_agent_manager(hass).async_get_agent( - conversation.HOME_ASSISTANT_AGENT - ) + agent = default_agent.async_get_default_agent(hass) with patch( "homeassistant.components.conversation.default_agent.get_languages", return_value=["dwarvish", "elvish", "entish"], @@ -181,6 +179,7 @@ async def test_expose_flag_automatically_set( # After setting up conversation, the expose flag should now be set on all entities assert async_get_assistant_settings(hass, conversation.DOMAIN) == { + "conversation.home_assistant": {"should_expose": False}, light.entity_id: {"should_expose": True}, test.entity_id: {"should_expose": False}, } @@ -190,6 +189,7 @@ async def test_expose_flag_automatically_set( hass.states.async_set(new_light, "test") await hass.async_block_till_done() assert async_get_assistant_settings(hass, conversation.DOMAIN) == { + "conversation.home_assistant": {"should_expose": False}, light.entity_id: {"should_expose": True}, new_light: {"should_expose": True}, test.entity_id: {"should_expose": False}, @@ -254,9 +254,7 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None: trigger_sentences = ["It's party time", "It is time to party"] trigger_response = "Cowabunga!" - agent = await agent_manager.get_agent_manager(hass).async_get_agent( - conversation.HOME_ASSISTANT_AGENT - ) + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) callback = AsyncMock(return_value=trigger_response) diff --git a/tests/components/conversation/test_entity.py b/tests/components/conversation/test_entity.py new file mode 100644 index 00000000000..c84f94c4aa4 --- /dev/null +++ b/tests/components/conversation/test_entity.py @@ -0,0 +1,47 @@ +"""Tests for conversation entity.""" + +from unittest.mock import patch + +from homeassistant.core import Context, HomeAssistant, State +from homeassistant.setup import async_setup_component +import homeassistant.util.dt as dt_util + +from tests.common import mock_restore_cache + + +async def test_state_set_and_restore(hass: HomeAssistant) -> None: + """Test we set and restore state in the integration.""" + entity_id = "conversation.home_assistant" + timestamp = "2023-01-01T23:59:59+00:00" + mock_restore_cache(hass, (State(entity_id, timestamp),)) + + await async_setup_component(hass, "homeassistant", {}) + await async_setup_component(hass, "conversation", {}) + + state = hass.states.get(entity_id) + assert state + assert state.state == timestamp + + now = dt_util.utcnow() + context = Context() + + with ( + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_process" + ) as mock_process, + patch("homeassistant.util.dt.utcnow", return_value=now), + ): + await hass.services.async_call( + "conversation", + "process", + {"text": "Hello"}, + context=context, + blocking=True, + ) + + assert len(mock_process.mock_calls) == 1 + + state = hass.states.get(entity_id) + assert state + assert state.state == now.isoformat() + assert state.context is context diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 62f67548ece..5b117c1ac70 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -9,7 +9,7 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation -from homeassistant.components.conversation import agent_manager, default_agent +from homeassistant.components.conversation import default_agent from homeassistant.components.conversation.models import ConversationInput from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN @@ -35,7 +35,13 @@ from tests.common import ( from tests.components.light.common import MockLight from tests.typing import ClientSessionGenerator, WebSocketGenerator -AGENT_ID_OPTIONS = [None, conversation.HOME_ASSISTANT_AGENT] +AGENT_ID_OPTIONS = [ + None, + # Old value of conversation.HOME_ASSISTANT_AGENT, + "homeassistant", + # Current value of conversation.HOME_ASSISTANT_AGENT, + "conversation.home_assistant", +] class OrderBeerIntentHandler(intent.IntentHandler): @@ -51,14 +57,6 @@ class OrderBeerIntentHandler(intent.IntentHandler): return response -@pytest.fixture -async def init_components(hass): - """Initialize relevant components with empty configs.""" - assert await async_setup_component(hass, "homeassistant", {}) - assert await async_setup_component(hass, "conversation", {}) - assert await async_setup_component(hass, "intent", {}) - - @pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) async def test_http_processing_intent( hass: HomeAssistant, @@ -752,7 +750,7 @@ async def test_ws_prepare( """Test the Websocket prepare conversation API.""" assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) - agent = await agent_manager.get_agent_manager(hass).async_get_agent() + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) # No intents should be loaded yet @@ -854,7 +852,7 @@ async def test_prepare_reload(hass: HomeAssistant) -> None: assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await agent_manager.get_agent_manager(hass).async_get_agent() + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) await agent.async_prepare(language) @@ -882,7 +880,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None: assert await async_setup_component(hass, "conversation", {}) # Load intents - agent = await agent_manager.get_agent_manager(hass).async_get_agent() + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) await agent.async_prepare("not-a-language") @@ -919,7 +917,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non hass.states.async_set("cover.front_door", "closed") calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER) - agent = await agent_manager.get_agent_manager(hass).async_get_agent() + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) result = await agent.async_process( @@ -1063,12 +1061,15 @@ async def test_light_area_same_name( assert call.data == {"entity_id": [kitchen_light.entity_id]} -async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None: +async def test_agent_id_validator_invalid_agent( + hass: HomeAssistant, init_components +) -> None: """Test validating agent id.""" with pytest.raises(vol.Invalid): conversation.agent_id_validator("invalid_agent") conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT) + conversation.agent_id_validator("conversation.home_assistant") async def test_get_agent_list( diff --git a/tests/components/conversation/test_trigger.py b/tests/components/conversation/test_trigger.py index 33ad8efdd2e..9e78b9b6180 100644 --- a/tests/components/conversation/test_trigger.py +++ b/tests/components/conversation/test_trigger.py @@ -5,7 +5,7 @@ import logging import pytest import voluptuous as vol -from homeassistant.components.conversation import agent_manager, default_agent +from homeassistant.components.conversation import default_agent from homeassistant.components.conversation.models import ConversationInput from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import trigger @@ -515,7 +515,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None: }, ) - agent = await agent_manager.get_agent_manager(hass).async_get_agent() + agent = default_agent.async_get_default_agent(hass) assert isinstance(agent, default_agent.DefaultAgent) result = await agent.async_process( diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index 7c2fc8291d4..11b3fbaa03f 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -327,6 +327,7 @@ async def test_conversation_agent( """Test GoogleAssistantConversationAgent.""" await setup_integration() + assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) entries = hass.config_entries.async_entries(DOMAIN) @@ -334,7 +335,7 @@ async def test_conversation_agent( entry = entries[0] assert entry.state is ConfigEntryState.LOADED - agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) + agent = conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES text1 = "tell me a joke" @@ -365,6 +366,7 @@ async def test_conversation_agent_refresh_token( """Test GoogleAssistantConversationAgent when token is expired.""" await setup_integration() + assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) entries = hass.config_entries.async_entries(DOMAIN) @@ -416,6 +418,7 @@ async def test_conversation_agent_language_changed( """Test GoogleAssistantConversationAgent when language is changed.""" await setup_integration() + assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) entries = hass.config_entries.async_entries(DOMAIN) diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 66dfd980cf3..5c979d3bc47 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -4,6 +4,8 @@ from unittest.mock import patch import pytest +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -23,10 +25,17 @@ def mock_config_entry(hass): @pytest.fixture -async def mock_init_component(hass, mock_config_entry): +async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry): """Initialize integration.""" + assert await async_setup_component(hass, "homeassistant", {}) with patch("google.generativeai.get_model"): assert await async_setup_component( hass, "google_generative_ai_conversation", {} ) await hass.async_block_till_done() + + +@pytest.fixture(autouse=True) +async def setup_ha(hass: HomeAssistant) -> None: + """Set up Home Assistant.""" + assert await async_setup_component(hass, "homeassistant", {}) diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 92e84b1fd39..befe3b93d12 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -10,6 +10,7 @@ from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import area_registry as ar, device_registry as dr, intent +from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -124,6 +125,7 @@ async def test_template_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test that template error handling works.""" + assert await async_setup_component(hass, "homeassistant", {}) hass.config_entries.async_update_entry( mock_config_entry, options={ @@ -152,7 +154,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test GoogleGenerativeAIAgent.""" - agent = await conversation.get_agent_manager(hass).async_get_agent( + agent = conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" diff --git a/tests/components/mobile_app/test_webhook.py b/tests/components/mobile_app/test_webhook.py index 9d941685c09..c67312939b1 100644 --- a/tests/components/mobile_app/test_webhook.py +++ b/tests/components/mobile_app/test_webhook.py @@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process( webhook_client.server.app.router._frozen = False with patch( - "homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent", + "homeassistant.components.conversation.agent_manager.async_get_agent", return_value=mock_conversation_agent, ): resp = await webhook_client.post( diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 78ecf0766d7..db1689bd416 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -35,3 +35,9 @@ async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfig ): assert await async_setup_component(hass, ollama.DOMAIN, {}) await hass.async_block_till_done() + + +@pytest.fixture(autouse=True) +async def setup_ha(hass: HomeAssistant) -> None: + """Set up Home Assistant.""" + assert await async_setup_component(hass, "homeassistant", {}) diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index 6dd9dc73973..5326a8ed609 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -229,7 +229,7 @@ async def test_message_history_pruning( assert isinstance(result.conversation_id, str) conversation_ids.append(result.conversation_id) - agent = await conversation.get_agent_manager(hass).async_get_agent( + agent = conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert isinstance(agent, ollama.OllamaAgent) @@ -284,7 +284,7 @@ async def test_message_history_unlimited( result.response.response_type == intent.IntentResponseType.ACTION_DONE ), result - agent = await conversation.get_agent_manager(hass).async_get_agent( + agent = conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert isinstance(agent, ollama.OllamaAgent) @@ -340,7 +340,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test OllamaAgent.""" - agent = await conversation.get_agent_manager(hass).async_get_agent( + agent = conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == MATCH_ALL diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index a8081c01c32..1597fa79d0a 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -30,3 +31,9 @@ async def mock_init_component(hass, mock_config_entry): ): assert await async_setup_component(hass, "openai_conversation", {}) await hass.async_block_till_done() + + +@pytest.fixture(autouse=True) +async def setup_ha(hass: HomeAssistant) -> None: + """Set up Home Assistant.""" + assert await async_setup_component(hass, "homeassistant", {}) diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index c94fdcebcde..2702b749a64 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -194,7 +194,7 @@ async def test_conversation_agent( mock_init_component, ) -> None: """Test OpenAIAgent.""" - agent = await conversation.get_agent_manager(hass).async_get_agent( + agent = conversation.get_agent_manager(hass).async_get_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*"