Add conversation entity (#114518)

* Default agent as entity

* Migrate constant to point at new value

* Fix tests

* Fix more tests

* Move assist pipeline back to cloud after dependenceis
This commit is contained in:
Paulus Schoutsen 2024-04-01 21:34:25 -04:00 committed by GitHub
parent b1af590eed
commit d2e4f5f36e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 566 additions and 177 deletions

View file

@ -376,6 +376,10 @@ class Pipeline:
This function was added in HA Core 2023.10, previous versions will raise
if there are unexpected items in the serialized data.
"""
# Migrate to new value for conversation agent
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
return cls(
conversation_engine=data["conversation_engine"],
conversation_language=data["conversation_language"],

View file

@ -223,7 +223,10 @@ class CloudLoginView(HomeAssistantView):
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"])
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
if "assist_pipeline" in hass.config.components:
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
else:
new_cloud_pipeline_id = None
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})

View file

@ -24,6 +24,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
@ -86,9 +87,18 @@ class CloudProviderEntity(SpeechToTextEntity):
async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.STT, engine_id=self.entity_id
)
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
"""When assist_pipeline is set up."""
assert self.platform.config_entry
self.platform.config_entry.async_create_task(
hass,
async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.STT, engine_id=self.entity_id
),
)
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]

View file

@ -27,6 +27,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
@ -156,9 +157,19 @@ class CloudTTSEntity(TextToSpeechEntity):
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
await super().async_added_to_hass()
await async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.TTS, engine_id=self.entity_id
)
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
"""When assist_pipeline is set up."""
assert self.platform.config_entry
self.platform.config_entry.async_create_task(
hass,
async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.TTS, engine_id=self.entity_id
),
)
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
self.async_on_remove(
self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from collections.abc import Iterable
import logging
import re
from typing import Literal
@ -20,6 +19,7 @@ from homeassistant.core import (
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
@ -27,15 +27,19 @@ from .agent_manager import (
AgentInfo,
agent_id_validator,
async_converse,
async_get_agent,
get_agent_manager,
)
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
from .const import HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
from .default_agent import async_get_default_agent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"async_converse",
"async_get_agent_info",
"async_set_agent",
@ -122,16 +126,26 @@ async def async_get_conversation_languages(
all conversation agents.
"""
agent_manager = get_agent_manager(hass)
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
languages: set[str] = set()
agents: list[ConversationEntity | AbstractConversationAgent]
if agent_id:
agent = async_get_agent(hass, agent_id)
if agent is None:
raise ValueError(f"Agent {agent_id} not found")
agents = [agent]
agent_ids: Iterable[str]
if agent_id is None:
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
else:
agent_ids = (agent_id,)
agents = list(entity_component.entities)
for info in agent_manager.async_get_agent_info():
agent = agent_manager.async_get_agent(info.id)
assert agent is not None
agents.append(agent)
for _agent_id in agent_ids:
agent = await agent_manager.async_get_agent(_agent_id)
for agent in agents:
if agent.supported_languages == MATCH_ALL:
return MATCH_ALL
for language_tag in agent.supported_languages:
@ -146,10 +160,18 @@ def async_get_agent_info(
agent_id: str | None = None,
) -> AgentInfo | None:
"""Get information on the agent or None if not found."""
manager = get_agent_manager(hass)
agent = async_get_agent(hass, agent_id)
if agent_id is None:
agent_id = manager.default_agent
if agent is None:
return None
if isinstance(agent, ConversationEntity):
name = agent.name
if not isinstance(name, str):
name = agent.entity_id
return AgentInfo(id=agent.entity_id, name=name)
manager = get_agent_manager(hass)
for agent_info in manager.async_get_agent_info():
if agent_info.id == agent_id:
@ -160,10 +182,11 @@ def async_get_agent_info(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
agent_manager = get_agent_manager(hass)
entity_component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
if config_intents := config.get(DOMAIN, {}).get("intents"):
hass.data[DATA_CONFIG] = config_intents
await async_setup_default_agent(
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
)
async def handle_process(service: ServiceCall) -> ServiceResponse:
"""Parse text into commands."""
@ -188,7 +211,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def handle_reload(service: ServiceCall) -> None:
"""Reload intents."""
agent = await agent_manager.async_get_agent()
agent = async_get_default_agent(hass)
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
hass.services.async_register(

View file

@ -2,8 +2,6 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import logging
from typing import Any
@ -11,10 +9,17 @@ import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
from .default_agent import async_get_default_agent
from .entity import ConversationEntity
from .models import (
AbstractConversationAgent,
AgentInfo,
ConversationInput,
ConversationResult,
)
_LOGGER = logging.getLogger(__name__)
@ -23,20 +28,37 @@ _LOGGER = logging.getLogger(__name__)
@callback
def get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent."""
manager = AgentManager(hass)
manager.async_setup()
return manager
return AgentManager(hass)
def agent_id_validator(value: Any) -> str:
"""Validate agent ID."""
hass = async_get_hass()
manager = get_agent_manager(hass)
if not manager.async_is_valid_agent_id(cv.string(value)):
if async_get_agent(hass, cv.string(value)) is None:
raise vol.Invalid("invalid agent ID")
return value
@callback
def async_get_agent(
hass: HomeAssistant, agent_id: str | None = None
) -> AbstractConversationAgent | ConversationEntity | None:
"""Get specified agent."""
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
return async_get_default_agent(hass)
if "." in agent_id:
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
return entity_component.get_entity(agent_id)
manager = get_agent_manager(hass)
if not manager.async_is_valid_agent_id(agent_id):
return None
return manager.async_get_agent(agent_id)
async def async_converse(
hass: HomeAssistant,
text: str,
@ -47,13 +69,22 @@ async def async_converse(
device_id: str | None = None,
) -> ConversationResult:
"""Process text and get intent."""
agent = await get_agent_manager(hass).async_get_agent(agent_id)
agent = async_get_agent(hass, agent_id)
if agent is None:
raise ValueError(f"Agent {agent_id} not found")
if isinstance(agent, ConversationEntity):
agent.async_set_context(context)
method = agent.internal_async_process
else:
method = agent.async_process
if language is None:
language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
result = await agent.async_process(
result = await method(
ConversationInput(
text=text,
context=context,
@ -65,52 +96,17 @@ async def async_converse(
return result
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
class AgentManager:
"""Class to manage conversation agents."""
default_agent: str = HOME_ASSISTANT_AGENT
_builtin_agent: AbstractConversationAgent | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the conversation agents."""
self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {}
self._builtin_agent_init_lock = asyncio.Lock()
def async_setup(self) -> None:
"""Set up the conversation agents."""
async_setup_default_agent(self.hass)
async def async_get_agent(
self, agent_id: str | None = None
) -> AbstractConversationAgent:
@callback
def async_get_agent(self, agent_id: str) -> AbstractConversationAgent | None:
"""Get the agent."""
if agent_id is None:
agent_id = self.default_agent
if agent_id == HOME_ASSISTANT_AGENT:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._builtin_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent
self._builtin_agent = DefaultAgent(self.hass)
await self._builtin_agent.async_initialize(
self.hass.data.get(DATA_CONFIG)
)
return self._builtin_agent
if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found")
@ -119,12 +115,7 @@ class AgentManager:
@callback
def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents."""
agents: list[AgentInfo] = [
AgentInfo(
id=HOME_ASSISTANT_AGENT,
name="Home Assistant",
)
]
agents: list[AgentInfo] = []
for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id)
@ -148,7 +139,7 @@ class AgentManager:
@callback
def async_is_valid_agent_id(self, agent_id: str) -> bool:
"""Check if the agent id is valid."""
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT
return agent_id in self._agents
@callback
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:

View file

@ -2,5 +2,5 @@
DOMAIN = "conversation"
DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
HOME_ASSISTANT_AGENT = "homeassistant"
DATA_CONFIG = "conversation_config"
HOME_ASSISTANT_AGENT = "conversation.home_assistant"
OLD_HOME_ASSISTANT_AGENT = "homeassistant"

View file

@ -24,7 +24,7 @@ from hassil.util import merge_dict
from home_assistant_intents import ErrorKey, get_intents, get_languages
import yaml
from homeassistant import core, setup
from homeassistant import core
from homeassistant.components.homeassistant.exposed_entities import (
async_listen_entity_updates,
async_should_expose,
@ -40,6 +40,7 @@ from homeassistant.helpers import (
template,
translation,
)
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_added_domain,
@ -47,7 +48,8 @@ from homeassistant.helpers.event import (
from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
_LOGGER = logging.getLogger(__name__)
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
@ -60,6 +62,14 @@ TRIGGER_CALLBACK_TYPE = Callable[
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
DATA_DEFAULT_ENTITY = "conversation_default_entity"
@core.callback
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
"""Get the default agent."""
return hass.data[DATA_DEFAULT_ENTITY]
def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents."""
@ -109,9 +119,16 @@ def _get_language_variations(language: str) -> Iterable[str]:
yield lang
@core.callback
def async_setup(hass: core.HomeAssistant) -> None:
async def async_setup_default_agent(
hass: core.HomeAssistant,
entity_component: EntityComponent[ConversationEntity],
config_intents: dict[str, Any],
) -> None:
"""Set up entity registry listener for the default agent."""
entity = DefaultAgent(hass, config_intents)
await entity_component.async_add_entities([entity])
hass.data[DATA_DEFAULT_ENTITY] = entity
entity_registry = er.async_get(hass)
for entity_id in entity_registry.entities:
async_should_expose(hass, DOMAIN, entity_id)
@ -131,17 +148,21 @@ def async_setup(hass: core.HomeAssistant) -> None:
start.async_at_started(hass, async_hass_started)
class DefaultAgent(AbstractConversationAgent):
class DefaultAgent(ConversationEntity):
"""Default agent for conversation agent."""
def __init__(self, hass: core.HomeAssistant) -> None:
_attr_name = "Home Assistant"
def __init__(
self, hass: core.HomeAssistant, config_intents: dict[str, Any]
) -> None:
"""Initialize the default agent."""
self.hass = hass
self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# intent -> [sentences]
self._config_intents: dict[str, Any] = {}
self._config_intents: dict[str, Any] = config_intents
self._slot_lists: dict[str, SlotList] | None = None
# Sentences that will trigger a callback (skipping intent recognition)
@ -154,15 +175,6 @@ class DefaultAgent(AbstractConversationAgent):
"""Return a list of supported languages."""
return get_languages()
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
"""Initialize the default agent."""
if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {})
# Intents from config may only contains sentences for HA config's language
if config_intents:
self._config_intents = config_intents
@core.callback
def _filter_entity_registry_changes(self, event_data: dict[str, Any]) -> bool:
"""Filter entity registry changed events."""

View file

@ -0,0 +1,57 @@
"""Entity for conversation integration."""
from abc import abstractmethod
from typing import Literal, final
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .models import ConversationInput, ConversationResult
class ConversationEntity(RestoreEntity):
"""Entity that supports conversations."""
_attr_should_poll = False
__last_activity: str | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_activity is None:
return None
return self.__last_activity
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_activity = state.state
@final
async def internal_async_process(
self, user_input: ConversationInput
) -> ConversationResult:
"""Process a sentence."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_process(user_input)
@property
@abstractmethod
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
@abstractmethod
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language."""

View file

@ -19,16 +19,24 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import language as language_util
from .agent_manager import agent_id_validator, async_converse, get_agent_manager
from .const import HOME_ASSISTANT_AGENT
from .agent_manager import (
agent_id_validator,
async_converse,
async_get_agent,
get_agent_manager,
)
from .const import DOMAIN
from .default_agent import (
METADATA_CUSTOM_FILE,
METADATA_CUSTOM_SENTENCE,
DefaultAgent,
SentenceTriggerResult,
async_get_default_agent,
)
from .entity import ConversationEntity
from .models import ConversationInput
@ -83,8 +91,14 @@ async def websocket_prepare(
msg: dict[str, Any],
) -> None:
"""Reload intents."""
manager = get_agent_manager(hass)
agent = await manager.async_get_agent(msg.get("agent_id"))
agent = async_get_agent(hass, msg.get("agent_id"))
if agent is None:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_FOUND, "Agent not found"
)
return
await agent.async_prepare(msg.get("language"))
connection.send_result(msg["id"])
@ -101,14 +115,32 @@ async def websocket_list_agents(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List conversation agents and, optionally, if they support a given language."""
manager = get_agent_manager(hass)
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
country = msg.get("country")
language = msg.get("language")
agents = []
for entity in entity_component.entities:
supported_languages = entity.supported_languages
if language and supported_languages != MATCH_ALL:
supported_languages = language_util.matches(
language, supported_languages, country
)
agents.append(
{
"id": entity.entity_id,
"name": entity.name or entity.entity_id,
"supported_languages": supported_languages,
}
)
manager = get_agent_manager(hass)
for agent_info in manager.async_get_agent_info():
agent = await manager.async_get_agent(agent_info.id)
agent = manager.async_get_agent(agent_info.id)
assert agent is not None
supported_languages = agent.supported_languages
if language and supported_languages != MATCH_ALL:
@ -139,7 +171,7 @@ async def websocket_hass_agent_debug(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Return intents that would be matched by the default agent for a list of sentences."""
agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
agent = async_get_default_agent(hass)
assert isinstance(agent, DefaultAgent)
results = [
await agent.async_recognize(

View file

@ -2,7 +2,7 @@
"domain": "conversation",
"name": "Conversation",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["http"],
"dependencies": ["http", "intent"],
"documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "system",
"iot_class": "local_push",

View file

@ -10,6 +10,14 @@ from homeassistant.core import Context
from homeassistant.helpers import intent
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
@dataclass(slots=True)
class ConversationInput:
"""User input to be processed."""

View file

@ -14,9 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from .agent_manager import get_agent_manager
from .const import DOMAIN, HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent
from .const import DOMAIN
from .default_agent import DefaultAgent, async_get_default_agent
def has_no_punctuation(value: list[str]) -> list[str]:
@ -111,7 +110,7 @@ async def async_attach_trigger(
# two trigger copies for who will provide a response.
return None
default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
default_agent = async_get_default_agent(hass)
assert isinstance(default_agent, DefaultAgent)
return default_agent.register_trigger(sentences, call_action)

View file

@ -34,7 +34,7 @@
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
}),
@ -123,7 +123,7 @@
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en-US',
}),
@ -212,7 +212,7 @@
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en-US',
}),
@ -325,7 +325,7 @@
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
}),

View file

@ -33,7 +33,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
})
@ -114,7 +114,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
})
@ -207,7 +207,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
})
@ -409,7 +409,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
'language': 'en',
})
@ -615,7 +615,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
'language': 'en',
})
@ -637,7 +637,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
'language': 'en',
})
@ -665,7 +665,7 @@
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'never mind',
'language': 'en',
})
@ -799,7 +799,7 @@
dict({
'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id',
'engine': 'homeassistant',
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
'language': 'en',
})

View file

@ -6,6 +6,7 @@ from unittest.mock import ANY, patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
@ -117,6 +118,7 @@ async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test loading stored pipelines on start."""
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
hass_storage[STORAGE_KEY] = {
"version": STORAGE_VERSION,
"minor_version": STORAGE_VERSION_MINOR,
@ -124,9 +126,9 @@ async def test_loading_pipelines_from_storage(
"data": {
"items": [
{
"conversation_engine": "conversation_engine_1",
"conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT,
"conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"id": id_1,
"language": "language_1",
"name": "name_1",
"stt_engine": "stt_engine_1",
@ -166,7 +168,7 @@ async def test_loading_pipelines_from_storage(
"wake_word_id": "wakeword_id_3",
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"preferred_item": id_1,
},
}
@ -175,7 +177,8 @@ async def test_loading_pipelines_from_storage(
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 3
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
assert store.async_get_preferred_item() == id_1
assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT
async def test_migrate_pipeline_store(
@ -262,7 +265,7 @@ async def test_create_default_pipeline(
tts_engine_id="test",
pipeline_name="Test pipeline",
) == Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language="en",
id=ANY,
language="en",
@ -304,7 +307,7 @@ async def test_get_pipelines(hass: HomeAssistant) -> None:
pipelines = async_get_pipelines(hass)
assert list(pipelines) == [
Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language="en",
id=ANY,
language="en",
@ -351,7 +354,7 @@ async def test_default_pipeline_no_stt_tts(
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
@ -414,7 +417,7 @@ async def test_default_pipeline(
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
@ -445,7 +448,7 @@ async def test_default_pipeline_unsupported_stt_language(
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language="en",
id=pipeline.id,
language="en",
@ -476,7 +479,7 @@ async def test_default_pipeline_unsupported_tts_language(
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language="en",
id=pipeline.id,
language="en",
@ -502,7 +505,7 @@ async def test_update_pipeline(
pipelines = list(pipelines)
assert pipelines == [
Pipeline(
conversation_engine="homeassistant",
conversation_engine="conversation.home_assistant",
conversation_language="en",
id=ANY,
language="en",

View file

@ -1166,7 +1166,7 @@ async def test_get_pipeline(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "homeassistant",
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en",
"id": ANY,
"language": "en",
@ -1250,7 +1250,7 @@ async def test_list_pipelines(
assert msg["result"] == {
"pipelines": [
{
"conversation_engine": "homeassistant",
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en",
"id": ANY,
"language": "en",
@ -2012,7 +2012,7 @@ async def test_wake_word_cooldown_different_entities(
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en-US",
"language": "en",
"name": "pipeline_with_wake_word_1",
@ -2032,7 +2032,7 @@ async def test_wake_word_cooldown_different_entities(
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en-US",
"language": "en",
"name": "pipeline_with_wake_word_2",

View file

@ -7,10 +7,12 @@ from homeassistant.components.cloud.assist_pipeline import (
)
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
"""Test migrate pipeline with invalid platform."""
await async_setup_component(hass, "assist_pipeline", {})
with pytest.raises(ValueError):
await async_migrate_cloud_pipeline_engine(
hass, Platform.BINARY_SENSOR, "test-engine-id"

View file

@ -231,6 +231,7 @@ async def test_login_view_create_pipeline(
}
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
@ -270,6 +271,7 @@ async def test_login_view_create_pipeline_fail(
}
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()

View file

@ -7,6 +7,8 @@ import pytest
from homeassistant.components import conversation
from homeassistant.components.shopping_list import intent as sl_intent
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MockAgent
@ -14,7 +16,7 @@ from tests.common import MockConfigEntry
@pytest.fixture
def mock_agent_support_all(hass):
def mock_agent_support_all(hass: HomeAssistant):
"""Mock agent that supports all languages."""
entry = MockConfigEntry(entry_id="mock-entry-support-all")
entry.add_to_hass(hass)
@ -34,7 +36,7 @@ def mock_shopping_list_io():
@pytest.fixture
async def sl_setup(hass):
async def sl_setup(hass: HomeAssistant):
"""Set up the shopping list."""
entry = MockConfigEntry(domain="shopping_list")
@ -43,3 +45,10 @@ async def sl_setup(hass):
assert await hass.config_entries.async_setup(entry.entry_id)
await sl_intent.async_setup_intents(hass)
@pytest.fixture
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})

View file

@ -101,7 +101,7 @@
# ---
# name: test_get_agent_info
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
})
# ---
@ -113,7 +113,7 @@
# ---
# name: test_get_agent_info.2
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
})
# ---
@ -127,7 +127,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
'af',
@ -207,7 +207,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
]),
@ -231,7 +231,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
'en',
@ -255,7 +255,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
'en',
@ -279,7 +279,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
'de',
@ -304,7 +304,7 @@
dict({
'agents': list([
dict({
'id': 'homeassistant',
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supported_languages': list([
'de-CH',
@ -415,6 +415,36 @@
}),
})
# ---
# name: test_http_processing_intent[conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': 'entity',
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_http_processing_intent[homeassistant]
dict({
'conversation_id': None,
@ -1035,6 +1065,36 @@
}),
})
# ---
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[None-turn kitchen on-homeassistant]
dict({
'conversation_id': None,
@ -1095,6 +1155,36 @@
}),
})
# ---
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[None-turn on kitchen-homeassistant]
dict({
'conversation_id': None,
@ -1155,6 +1245,36 @@
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
dict({
'conversation_id': None,
@ -1215,6 +1335,36 @@
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
dict({
'id': 'light.kitchen',
'name': 'kitchen',
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
}),
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Turned on the light',
}),
}),
}),
})
# ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
dict({
'conversation_id': None,

View file

@ -7,7 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.conversation import default_agent
from homeassistant.components.homeassistant.exposed_entities import (
async_get_assistant_settings,
)
@ -152,9 +152,7 @@ async def test_conversation_agent(
init_components,
) -> None:
"""Test DefaultAgent."""
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
agent = default_agent.async_get_default_agent(hass)
with patch(
"homeassistant.components.conversation.default_agent.get_languages",
return_value=["dwarvish", "elvish", "entish"],
@ -181,6 +179,7 @@ async def test_expose_flag_automatically_set(
# After setting up conversation, the expose flag should now be set on all entities
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
"conversation.home_assistant": {"should_expose": False},
light.entity_id: {"should_expose": True},
test.entity_id: {"should_expose": False},
}
@ -190,6 +189,7 @@ async def test_expose_flag_automatically_set(
hass.states.async_set(new_light, "test")
await hass.async_block_till_done()
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
"conversation.home_assistant": {"should_expose": False},
light.entity_id: {"should_expose": True},
new_light: {"should_expose": True},
test.entity_id: {"should_expose": False},
@ -254,9 +254,7 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!"
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
callback = AsyncMock(return_value=trigger_response)

View file

@ -0,0 +1,47 @@
"""Tests for conversation entity."""
from unittest.mock import patch
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from tests.common import mock_restore_cache
async def test_state_set_and_restore(hass: HomeAssistant) -> None:
"""Test we set and restore state in the integration."""
entity_id = "conversation.home_assistant"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
await async_setup_component(hass, "homeassistant", {})
await async_setup_component(hass, "conversation", {})
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp
now = dt_util.utcnow()
context = Context()
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process"
) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now),
):
await hass.services.async_call(
"conversation",
"process",
{"text": "Hello"},
context=context,
blocking=True,
)
assert len(mock_process.mock_calls) == 1
state = hass.states.get(entity_id)
assert state
assert state.state == now.isoformat()
assert state.context is context

View file

@ -9,7 +9,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
@ -35,7 +35,13 @@ from tests.common import (
from tests.components.light.common import MockLight
from tests.typing import ClientSessionGenerator, WebSocketGenerator
AGENT_ID_OPTIONS = [None, conversation.HOME_ASSISTANT_AGENT]
AGENT_ID_OPTIONS = [
None,
# Old value of conversation.HOME_ASSISTANT_AGENT,
"homeassistant",
# Current value of conversation.HOME_ASSISTANT_AGENT,
"conversation.home_assistant",
]
class OrderBeerIntentHandler(intent.IntentHandler):
@ -51,14 +57,6 @@ class OrderBeerIntentHandler(intent.IntentHandler):
return response
@pytest.fixture
async def init_components(hass):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
async def test_http_processing_intent(
hass: HomeAssistant,
@ -752,7 +750,7 @@ async def test_ws_prepare(
"""Test the Websocket prepare conversation API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
# No intents should be loaded yet
@ -854,7 +852,7 @@ async def test_prepare_reload(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare(language)
@ -882,7 +880,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
# Load intents
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
await agent.async_prepare("not-a-language")
@ -919,7 +917,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
hass.states.async_set("cover.front_door", "closed")
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(
@ -1063,12 +1061,15 @@ async def test_light_area_same_name(
assert call.data == {"entity_id": [kitchen_light.entity_id]}
async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
async def test_agent_id_validator_invalid_agent(
hass: HomeAssistant, init_components
) -> None:
"""Test validating agent id."""
with pytest.raises(vol.Invalid):
conversation.agent_id_validator("invalid_agent")
conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT)
conversation.agent_id_validator("conversation.home_assistant")
async def test_get_agent_list(

View file

@ -5,7 +5,7 @@ import logging
import pytest
import voluptuous as vol
from homeassistant.components.conversation import agent_manager, default_agent
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation.models import ConversationInput
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import trigger
@ -515,7 +515,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
},
)
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
agent = default_agent.async_get_default_agent(hass)
assert isinstance(agent, default_agent.DefaultAgent)
result = await agent.async_process(

View file

@ -327,6 +327,7 @@ async def test_conversation_agent(
"""Test GoogleAssistantConversationAgent."""
await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
@ -334,7 +335,7 @@ async def test_conversation_agent(
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
agent = conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
text1 = "tell me a joke"
@ -365,6 +366,7 @@ async def test_conversation_agent_refresh_token(
"""Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
@ -416,6 +418,7 @@ async def test_conversation_agent_language_changed(
"""Test GoogleAssistantConversationAgent when language is changed."""
await setup_integration()
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)

View file

@ -4,6 +4,8 @@ from unittest.mock import patch
import pytest
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -23,10 +25,17 @@ def mock_config_entry(hass):
@pytest.fixture
async def mock_init_component(hass, mock_config_entry):
async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
"""Initialize integration."""
assert await async_setup_component(hass, "homeassistant", {})
with patch("google.generativeai.get_model"):
assert await async_setup_component(
hass, "google_generative_ai_conversation", {}
)
await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View file

@ -10,6 +10,7 @@ from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -124,6 +125,7 @@ async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works."""
assert await async_setup_component(hass, "homeassistant", {})
hass.config_entries.async_update_entry(
mock_config_entry,
options={
@ -152,7 +154,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test GoogleGenerativeAIAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent(
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"

View file

@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process(
webhook_client.server.app.router._frozen = False
with patch(
"homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent",
"homeassistant.components.conversation.agent_manager.async_get_agent",
return_value=mock_conversation_agent,
):
resp = await webhook_client.post(

View file

@ -35,3 +35,9 @@ async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfig
):
assert await async_setup_component(hass, ollama.DOMAIN, {})
await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View file

@ -229,7 +229,7 @@ async def test_message_history_pruning(
assert isinstance(result.conversation_id, str)
conversation_ids.append(result.conversation_id)
agent = await conversation.get_agent_manager(hass).async_get_agent(
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert isinstance(agent, ollama.OllamaAgent)
@ -284,7 +284,7 @@ async def test_message_history_unlimited(
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
agent = await conversation.get_agent_manager(hass).async_get_agent(
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert isinstance(agent, ollama.OllamaAgent)
@ -340,7 +340,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test OllamaAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent(
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == MATCH_ALL

View file

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -30,3 +31,9 @@ async def mock_init_component(hass, mock_config_entry):
):
assert await async_setup_component(hass, "openai_conversation", {})
await hass.async_block_till_done()
@pytest.fixture(autouse=True)
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})

View file

@ -194,7 +194,7 @@ async def test_conversation_agent(
mock_init_component,
) -> None:
"""Test OpenAIAgent."""
agent = await conversation.get_agent_manager(hass).async_get_agent(
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"