Add conversation entity (#114518)
* Default agent as entity * Migrate constant to point at new value * Fix tests * Fix more tests * Move assist pipeline back to cloud after dependenceis
This commit is contained in:
parent
b1af590eed
commit
d2e4f5f36e
33 changed files with 566 additions and 177 deletions
|
@ -376,6 +376,10 @@ class Pipeline:
|
|||
This function was added in HA Core 2023.10, previous versions will raise
|
||||
if there are unexpected items in the serialized data.
|
||||
"""
|
||||
# Migrate to new value for conversation agent
|
||||
if data["conversation_engine"] == conversation.OLD_HOME_ASSISTANT_AGENT:
|
||||
data["conversation_engine"] = conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
return cls(
|
||||
conversation_engine=data["conversation_engine"],
|
||||
conversation_language=data["conversation_language"],
|
||||
|
|
|
@ -223,7 +223,10 @@ class CloudLoginView(HomeAssistantView):
|
|||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
await cloud.login(data["email"], data["password"])
|
||||
|
||||
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
||||
if "assist_pipeline" in hass.config.components:
|
||||
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
||||
else:
|
||||
new_cloud_pipeline_id = None
|
||||
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_when_setup
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||
from .client import CloudClient
|
||||
|
@ -86,9 +87,18 @@ class CloudProviderEntity(SpeechToTextEntity):
|
|||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity is about to be added to hass."""
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.STT, engine_id=self.entity_id
|
||||
)
|
||||
|
||||
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
|
||||
"""When assist_pipeline is set up."""
|
||||
assert self.platform.config_entry
|
||||
self.platform.config_entry.async_create_task(
|
||||
hass,
|
||||
async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.STT, engine_id=self.entity_id
|
||||
),
|
||||
)
|
||||
|
||||
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
|
|
|
@ -27,6 +27,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_when_setup
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||
from .client import CloudClient
|
||||
|
@ -156,9 +157,19 @@ class CloudTTSEntity(TextToSpeechEntity):
|
|||
async def async_added_to_hass(self) -> None:
|
||||
"""Handle entity which will be added."""
|
||||
await super().async_added_to_hass()
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.TTS, engine_id=self.entity_id
|
||||
)
|
||||
|
||||
async def pipeline_setup(hass: HomeAssistant, _comp: str) -> None:
|
||||
"""When assist_pipeline is set up."""
|
||||
assert self.platform.config_entry
|
||||
self.platform.config_entry.async_create_task(
|
||||
hass,
|
||||
async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.TTS, engine_id=self.entity_id
|
||||
),
|
||||
)
|
||||
|
||||
async_when_setup(self.hass, "assist_pipeline", pipeline_setup)
|
||||
|
||||
self.async_on_remove(
|
||||
self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
@ -20,6 +19,7 @@ from homeassistant.core import (
|
|||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, intent
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
|
@ -27,15 +27,19 @@ from .agent_manager import (
|
|||
AgentInfo,
|
||||
agent_id_validator,
|
||||
async_converse,
|
||||
async_get_agent,
|
||||
get_agent_manager,
|
||||
)
|
||||
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
|
||||
from .const import HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
|
||||
from .default_agent import async_get_default_agent, async_setup_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"OLD_HOME_ASSISTANT_AGENT",
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
"async_set_agent",
|
||||
|
@ -122,16 +126,26 @@ async def async_get_conversation_languages(
|
|||
all conversation agents.
|
||||
"""
|
||||
agent_manager = get_agent_manager(hass)
|
||||
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
|
||||
languages: set[str] = set()
|
||||
agents: list[ConversationEntity | AbstractConversationAgent]
|
||||
|
||||
if agent_id:
|
||||
agent = async_get_agent(hass, agent_id)
|
||||
|
||||
if agent is None:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
agents = [agent]
|
||||
|
||||
agent_ids: Iterable[str]
|
||||
if agent_id is None:
|
||||
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
|
||||
else:
|
||||
agent_ids = (agent_id,)
|
||||
agents = list(entity_component.entities)
|
||||
for info in agent_manager.async_get_agent_info():
|
||||
agent = agent_manager.async_get_agent(info.id)
|
||||
assert agent is not None
|
||||
agents.append(agent)
|
||||
|
||||
for _agent_id in agent_ids:
|
||||
agent = await agent_manager.async_get_agent(_agent_id)
|
||||
for agent in agents:
|
||||
if agent.supported_languages == MATCH_ALL:
|
||||
return MATCH_ALL
|
||||
for language_tag in agent.supported_languages:
|
||||
|
@ -146,10 +160,18 @@ def async_get_agent_info(
|
|||
agent_id: str | None = None,
|
||||
) -> AgentInfo | None:
|
||||
"""Get information on the agent or None if not found."""
|
||||
manager = get_agent_manager(hass)
|
||||
agent = async_get_agent(hass, agent_id)
|
||||
|
||||
if agent_id is None:
|
||||
agent_id = manager.default_agent
|
||||
if agent is None:
|
||||
return None
|
||||
|
||||
if isinstance(agent, ConversationEntity):
|
||||
name = agent.name
|
||||
if not isinstance(name, str):
|
||||
name = agent.entity_id
|
||||
return AgentInfo(id=agent.entity_id, name=name)
|
||||
|
||||
manager = get_agent_manager(hass)
|
||||
|
||||
for agent_info in manager.async_get_agent_info():
|
||||
if agent_info.id == agent_id:
|
||||
|
@ -160,10 +182,11 @@ def async_get_agent_info(
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
agent_manager = get_agent_manager(hass)
|
||||
entity_component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
||||
hass.data[DATA_CONFIG] = config_intents
|
||||
await async_setup_default_agent(
|
||||
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
|
||||
)
|
||||
|
||||
async def handle_process(service: ServiceCall) -> ServiceResponse:
|
||||
"""Parse text into commands."""
|
||||
|
@ -188,7 +211,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
async def handle_reload(service: ServiceCall) -> None:
|
||||
"""Reload intents."""
|
||||
agent = await agent_manager.async_get_agent()
|
||||
agent = async_get_default_agent(hass)
|
||||
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
|
||||
|
||||
hass.services.async_register(
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
@ -11,10 +9,17 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
|
||||
from homeassistant.helpers import config_validation as cv, singleton
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DATA_CONFIG, HOME_ASSISTANT_AGENT
|
||||
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .const import DOMAIN, HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT
|
||||
from .default_agent import async_get_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .models import (
|
||||
AbstractConversationAgent,
|
||||
AgentInfo,
|
||||
ConversationInput,
|
||||
ConversationResult,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -23,20 +28,37 @@ _LOGGER = logging.getLogger(__name__)
|
|||
@callback
|
||||
def get_agent_manager(hass: HomeAssistant) -> AgentManager:
|
||||
"""Get the active agent."""
|
||||
manager = AgentManager(hass)
|
||||
manager.async_setup()
|
||||
return manager
|
||||
return AgentManager(hass)
|
||||
|
||||
|
||||
def agent_id_validator(value: Any) -> str:
|
||||
"""Validate agent ID."""
|
||||
hass = async_get_hass()
|
||||
manager = get_agent_manager(hass)
|
||||
if not manager.async_is_valid_agent_id(cv.string(value)):
|
||||
if async_get_agent(hass, cv.string(value)) is None:
|
||||
raise vol.Invalid("invalid agent ID")
|
||||
return value
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_agent(
|
||||
hass: HomeAssistant, agent_id: str | None = None
|
||||
) -> AbstractConversationAgent | ConversationEntity | None:
|
||||
"""Get specified agent."""
|
||||
if agent_id is None or agent_id in (HOME_ASSISTANT_AGENT, OLD_HOME_ASSISTANT_AGENT):
|
||||
return async_get_default_agent(hass)
|
||||
|
||||
if "." in agent_id:
|
||||
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
|
||||
return entity_component.get_entity(agent_id)
|
||||
|
||||
manager = get_agent_manager(hass)
|
||||
|
||||
if not manager.async_is_valid_agent_id(agent_id):
|
||||
return None
|
||||
|
||||
return manager.async_get_agent(agent_id)
|
||||
|
||||
|
||||
async def async_converse(
|
||||
hass: HomeAssistant,
|
||||
text: str,
|
||||
|
@ -47,13 +69,22 @@ async def async_converse(
|
|||
device_id: str | None = None,
|
||||
) -> ConversationResult:
|
||||
"""Process text and get intent."""
|
||||
agent = await get_agent_manager(hass).async_get_agent(agent_id)
|
||||
agent = async_get_agent(hass, agent_id)
|
||||
|
||||
if agent is None:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
if isinstance(agent, ConversationEntity):
|
||||
agent.async_set_context(context)
|
||||
method = agent.internal_async_process
|
||||
else:
|
||||
method = agent.async_process
|
||||
|
||||
if language is None:
|
||||
language = hass.config.language
|
||||
|
||||
_LOGGER.debug("Processing in %s: %s", language, text)
|
||||
result = await agent.async_process(
|
||||
result = await method(
|
||||
ConversationInput(
|
||||
text=text,
|
||||
context=context,
|
||||
|
@ -65,52 +96,17 @@ async def async_converse(
|
|||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInfo:
|
||||
"""Container for conversation agent info."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""Class to manage conversation agents."""
|
||||
|
||||
default_agent: str = HOME_ASSISTANT_AGENT
|
||||
_builtin_agent: AbstractConversationAgent | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the conversation agents."""
|
||||
self.hass = hass
|
||||
self._agents: dict[str, AbstractConversationAgent] = {}
|
||||
self._builtin_agent_init_lock = asyncio.Lock()
|
||||
|
||||
def async_setup(self) -> None:
|
||||
"""Set up the conversation agents."""
|
||||
async_setup_default_agent(self.hass)
|
||||
|
||||
async def async_get_agent(
|
||||
self, agent_id: str | None = None
|
||||
) -> AbstractConversationAgent:
|
||||
@callback
|
||||
def async_get_agent(self, agent_id: str) -> AbstractConversationAgent | None:
|
||||
"""Get the agent."""
|
||||
if agent_id is None:
|
||||
agent_id = self.default_agent
|
||||
|
||||
if agent_id == HOME_ASSISTANT_AGENT:
|
||||
if self._builtin_agent is not None:
|
||||
return self._builtin_agent
|
||||
|
||||
async with self._builtin_agent_init_lock:
|
||||
if self._builtin_agent is not None:
|
||||
return self._builtin_agent
|
||||
|
||||
self._builtin_agent = DefaultAgent(self.hass)
|
||||
await self._builtin_agent.async_initialize(
|
||||
self.hass.data.get(DATA_CONFIG)
|
||||
)
|
||||
|
||||
return self._builtin_agent
|
||||
|
||||
if agent_id not in self._agents:
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
|
@ -119,12 +115,7 @@ class AgentManager:
|
|||
@callback
|
||||
def async_get_agent_info(self) -> list[AgentInfo]:
|
||||
"""List all agents."""
|
||||
agents: list[AgentInfo] = [
|
||||
AgentInfo(
|
||||
id=HOME_ASSISTANT_AGENT,
|
||||
name="Home Assistant",
|
||||
)
|
||||
]
|
||||
agents: list[AgentInfo] = []
|
||||
for agent_id, agent in self._agents.items():
|
||||
config_entry = self.hass.config_entries.async_get_entry(agent_id)
|
||||
|
||||
|
@ -148,7 +139,7 @@ class AgentManager:
|
|||
@callback
|
||||
def async_is_valid_agent_id(self, agent_id: str) -> bool:
|
||||
"""Check if the agent id is valid."""
|
||||
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT
|
||||
return agent_id in self._agents
|
||||
|
||||
@callback
|
||||
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
||||
|
|
|
@ -2,5 +2,5 @@
|
|||
|
||||
DOMAIN = "conversation"
|
||||
DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
|
||||
HOME_ASSISTANT_AGENT = "homeassistant"
|
||||
DATA_CONFIG = "conversation_config"
|
||||
HOME_ASSISTANT_AGENT = "conversation.home_assistant"
|
||||
OLD_HOME_ASSISTANT_AGENT = "homeassistant"
|
||||
|
|
|
@ -24,7 +24,7 @@ from hassil.util import merge_dict
|
|||
from home_assistant_intents import ErrorKey, get_intents, get_languages
|
||||
import yaml
|
||||
|
||||
from homeassistant import core, setup
|
||||
from homeassistant import core
|
||||
from homeassistant.components.homeassistant.exposed_entities import (
|
||||
async_listen_entity_updates,
|
||||
async_should_expose,
|
||||
|
@ -40,6 +40,7 @@ from homeassistant.helpers import (
|
|||
template,
|
||||
translation,
|
||||
)
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import (
|
||||
EventStateChangedData,
|
||||
async_track_state_added_domain,
|
||||
|
@ -47,7 +48,8 @@ from homeassistant.helpers.event import (
|
|||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
|
||||
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
||||
|
@ -60,6 +62,14 @@ TRIGGER_CALLBACK_TYPE = Callable[
|
|||
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
||||
METADATA_CUSTOM_FILE = "hass_custom_file"
|
||||
|
||||
DATA_DEFAULT_ENTITY = "conversation_default_entity"
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_get_default_agent(hass: core.HomeAssistant) -> DefaultAgent:
|
||||
"""Get the default agent."""
|
||||
return hass.data[DATA_DEFAULT_ENTITY]
|
||||
|
||||
|
||||
def json_load(fp: IO[str]) -> JsonObjectType:
|
||||
"""Wrap json_loads for get_intents."""
|
||||
|
@ -109,9 +119,16 @@ def _get_language_variations(language: str) -> Iterable[str]:
|
|||
yield lang
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_setup(hass: core.HomeAssistant) -> None:
|
||||
async def async_setup_default_agent(
|
||||
hass: core.HomeAssistant,
|
||||
entity_component: EntityComponent[ConversationEntity],
|
||||
config_intents: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set up entity registry listener for the default agent."""
|
||||
entity = DefaultAgent(hass, config_intents)
|
||||
await entity_component.async_add_entities([entity])
|
||||
hass.data[DATA_DEFAULT_ENTITY] = entity
|
||||
|
||||
entity_registry = er.async_get(hass)
|
||||
for entity_id in entity_registry.entities:
|
||||
async_should_expose(hass, DOMAIN, entity_id)
|
||||
|
@ -131,17 +148,21 @@ def async_setup(hass: core.HomeAssistant) -> None:
|
|||
start.async_at_started(hass, async_hass_started)
|
||||
|
||||
|
||||
class DefaultAgent(AbstractConversationAgent):
|
||||
class DefaultAgent(ConversationEntity):
|
||||
"""Default agent for conversation agent."""
|
||||
|
||||
def __init__(self, hass: core.HomeAssistant) -> None:
|
||||
_attr_name = "Home Assistant"
|
||||
|
||||
def __init__(
|
||||
self, hass: core.HomeAssistant, config_intents: dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize the default agent."""
|
||||
self.hass = hass
|
||||
self._lang_intents: dict[str, LanguageIntents] = {}
|
||||
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
|
||||
# intent -> [sentences]
|
||||
self._config_intents: dict[str, Any] = {}
|
||||
self._config_intents: dict[str, Any] = config_intents
|
||||
self._slot_lists: dict[str, SlotList] | None = None
|
||||
|
||||
# Sentences that will trigger a callback (skipping intent recognition)
|
||||
|
@ -154,15 +175,6 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
"""Return a list of supported languages."""
|
||||
return get_languages()
|
||||
|
||||
async def async_initialize(self, config_intents: dict[str, Any] | None) -> None:
|
||||
"""Initialize the default agent."""
|
||||
if "intent" not in self.hass.config.components:
|
||||
await setup.async_setup_component(self.hass, "intent", {})
|
||||
|
||||
# Intents from config may only contains sentences for HA config's language
|
||||
if config_intents:
|
||||
self._config_intents = config_intents
|
||||
|
||||
@core.callback
|
||||
def _filter_entity_registry_changes(self, event_data: dict[str, Any]) -> bool:
|
||||
"""Filter entity registry changed events."""
|
||||
|
|
57
homeassistant/components/conversation/entity.py
Normal file
57
homeassistant/components/conversation/entity.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
"""Entity for conversation integration."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Literal, final
|
||||
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
|
||||
class ConversationEntity(RestoreEntity):
|
||||
"""Entity that supports conversations."""
|
||||
|
||||
_attr_should_poll = False
|
||||
__last_activity: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the entity."""
|
||||
if self.__last_activity is None:
|
||||
return None
|
||||
return self.__last_activity
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if (
|
||||
state is not None
|
||||
and state.state is not None
|
||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
):
|
||||
self.__last_activity = state.state
|
||||
|
||||
@final
|
||||
async def internal_async_process(
|
||||
self, user_input: ConversationInput
|
||||
) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
self.__last_activity = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_process(user_input)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
|
||||
async def async_prepare(self, language: str | None = None) -> None:
|
||||
"""Load intents for a language."""
|
|
@ -19,16 +19,24 @@ from homeassistant.components.http.data_validator import RequestDataValidator
|
|||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant, State, callback
|
||||
from homeassistant.helpers import config_validation as cv, intent
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .agent_manager import agent_id_validator, async_converse, get_agent_manager
|
||||
from .const import HOME_ASSISTANT_AGENT
|
||||
from .agent_manager import (
|
||||
agent_id_validator,
|
||||
async_converse,
|
||||
async_get_agent,
|
||||
get_agent_manager,
|
||||
)
|
||||
from .const import DOMAIN
|
||||
from .default_agent import (
|
||||
METADATA_CUSTOM_FILE,
|
||||
METADATA_CUSTOM_SENTENCE,
|
||||
DefaultAgent,
|
||||
SentenceTriggerResult,
|
||||
async_get_default_agent,
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput
|
||||
|
||||
|
||||
|
@ -83,8 +91,14 @@ async def websocket_prepare(
|
|||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Reload intents."""
|
||||
manager = get_agent_manager(hass)
|
||||
agent = await manager.async_get_agent(msg.get("agent_id"))
|
||||
agent = async_get_agent(hass, msg.get("agent_id"))
|
||||
|
||||
if agent is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_NOT_FOUND, "Agent not found"
|
||||
)
|
||||
return
|
||||
|
||||
await agent.async_prepare(msg.get("language"))
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
@ -101,14 +115,32 @@ async def websocket_list_agents(
|
|||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List conversation agents and, optionally, if they support a given language."""
|
||||
manager = get_agent_manager(hass)
|
||||
entity_component: EntityComponent[ConversationEntity] = hass.data[DOMAIN]
|
||||
|
||||
country = msg.get("country")
|
||||
language = msg.get("language")
|
||||
agents = []
|
||||
|
||||
for entity in entity_component.entities:
|
||||
supported_languages = entity.supported_languages
|
||||
if language and supported_languages != MATCH_ALL:
|
||||
supported_languages = language_util.matches(
|
||||
language, supported_languages, country
|
||||
)
|
||||
|
||||
agents.append(
|
||||
{
|
||||
"id": entity.entity_id,
|
||||
"name": entity.name or entity.entity_id,
|
||||
"supported_languages": supported_languages,
|
||||
}
|
||||
)
|
||||
|
||||
manager = get_agent_manager(hass)
|
||||
|
||||
for agent_info in manager.async_get_agent_info():
|
||||
agent = await manager.async_get_agent(agent_info.id)
|
||||
agent = manager.async_get_agent(agent_info.id)
|
||||
assert agent is not None
|
||||
|
||||
supported_languages = agent.supported_languages
|
||||
if language and supported_languages != MATCH_ALL:
|
||||
|
@ -139,7 +171,7 @@ async def websocket_hass_agent_debug(
|
|||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Return intents that would be matched by the default agent for a list of sentences."""
|
||||
agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
|
||||
agent = async_get_default_agent(hass)
|
||||
assert isinstance(agent, DefaultAgent)
|
||||
results = [
|
||||
await agent.async_recognize(
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
"domain": "conversation",
|
||||
"name": "Conversation",
|
||||
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
||||
"dependencies": ["http"],
|
||||
"dependencies": ["http", "intent"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/conversation",
|
||||
"integration_type": "system",
|
||||
"iot_class": "local_push",
|
||||
|
|
|
@ -10,6 +10,14 @@ from homeassistant.core import Context
|
|||
from homeassistant.helpers import intent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInfo:
|
||||
"""Container for conversation agent info."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversationInput:
|
||||
"""User input to be processed."""
|
||||
|
|
|
@ -14,9 +14,8 @@ from homeassistant.helpers.script import ScriptRunResult
|
|||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||
|
||||
from .agent_manager import get_agent_manager
|
||||
from .const import DOMAIN, HOME_ASSISTANT_AGENT
|
||||
from .default_agent import DefaultAgent
|
||||
from .const import DOMAIN
|
||||
from .default_agent import DefaultAgent, async_get_default_agent
|
||||
|
||||
|
||||
def has_no_punctuation(value: list[str]) -> list[str]:
|
||||
|
@ -111,7 +110,7 @@ async def async_attach_trigger(
|
|||
# two trigger copies for who will provide a response.
|
||||
return None
|
||||
|
||||
default_agent = await get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
|
||||
default_agent = async_get_default_agent(hass)
|
||||
assert isinstance(default_agent, DefaultAgent)
|
||||
|
||||
return default_agent.register_trigger(sentences, call_action)
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
}),
|
||||
|
@ -123,7 +123,7 @@
|
|||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en-US',
|
||||
}),
|
||||
|
@ -212,7 +212,7 @@
|
|||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en-US',
|
||||
}),
|
||||
|
@ -325,7 +325,7 @@
|
|||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
}),
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -114,7 +114,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -207,7 +207,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -409,7 +409,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -615,7 +615,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -637,7 +637,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -665,7 +665,7 @@
|
|||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'never mind',
|
||||
'language': 'en',
|
||||
})
|
||||
|
@ -799,7 +799,7 @@
|
|||
dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': 'mock-device-id',
|
||||
'engine': 'homeassistant',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': 'en',
|
||||
})
|
||||
|
|
|
@ -6,6 +6,7 @@ from unittest.mock import ANY, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
STORAGE_KEY,
|
||||
|
@ -117,6 +118,7 @@ async def test_loading_pipelines_from_storage(
|
|||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test loading stored pipelines on start."""
|
||||
id_1 = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": STORAGE_VERSION,
|
||||
"minor_version": STORAGE_VERSION_MINOR,
|
||||
|
@ -124,9 +126,9 @@ async def test_loading_pipelines_from_storage(
|
|||
"data": {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT,
|
||||
"conversation_language": "language_1",
|
||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"id": id_1,
|
||||
"language": "language_1",
|
||||
"name": "name_1",
|
||||
"stt_engine": "stt_engine_1",
|
||||
|
@ -166,7 +168,7 @@ async def test_loading_pipelines_from_storage(
|
|||
"wake_word_id": "wakeword_id_3",
|
||||
},
|
||||
],
|
||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"preferred_item": id_1,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -175,7 +177,8 @@ async def test_loading_pipelines_from_storage(
|
|||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 3
|
||||
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
assert store.async_get_preferred_item() == id_1
|
||||
assert store.data[id_1].conversation_engine == conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
|
||||
async def test_migrate_pipeline_store(
|
||||
|
@ -262,7 +265,7 @@ async def test_create_default_pipeline(
|
|||
tts_engine_id="test",
|
||||
pipeline_name="Test pipeline",
|
||||
) == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language="en",
|
||||
id=ANY,
|
||||
language="en",
|
||||
|
@ -304,7 +307,7 @@ async def test_get_pipelines(hass: HomeAssistant) -> None:
|
|||
pipelines = async_get_pipelines(hass)
|
||||
assert list(pipelines) == [
|
||||
Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language="en",
|
||||
id=ANY,
|
||||
language="en",
|
||||
|
@ -351,7 +354,7 @@ async def test_default_pipeline_no_stt_tts(
|
|||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language=conv_language,
|
||||
id=pipeline.id,
|
||||
language=pipeline_language,
|
||||
|
@ -414,7 +417,7 @@ async def test_default_pipeline(
|
|||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language=conv_language,
|
||||
id=pipeline.id,
|
||||
language=pipeline_language,
|
||||
|
@ -445,7 +448,7 @@ async def test_default_pipeline_unsupported_stt_language(
|
|||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
|
@ -476,7 +479,7 @@ async def test_default_pipeline_unsupported_tts_language(
|
|||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
|
@ -502,7 +505,7 @@ async def test_update_pipeline(
|
|||
pipelines = list(pipelines)
|
||||
assert pipelines == [
|
||||
Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_engine="conversation.home_assistant",
|
||||
conversation_language="en",
|
||||
id=ANY,
|
||||
language="en",
|
||||
|
|
|
@ -1166,7 +1166,7 @@ async def test_get_pipeline(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_engine": "conversation.home_assistant",
|
||||
"conversation_language": "en",
|
||||
"id": ANY,
|
||||
"language": "en",
|
||||
|
@ -1250,7 +1250,7 @@ async def test_list_pipelines(
|
|||
assert msg["result"] == {
|
||||
"pipelines": [
|
||||
{
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_engine": "conversation.home_assistant",
|
||||
"conversation_language": "en",
|
||||
"id": ANY,
|
||||
"language": "en",
|
||||
|
@ -2012,7 +2012,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||
await client_pipeline.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_engine": "conversation.home_assistant",
|
||||
"conversation_language": "en-US",
|
||||
"language": "en",
|
||||
"name": "pipeline_with_wake_word_1",
|
||||
|
@ -2032,7 +2032,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||
await client_pipeline.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_engine": "conversation.home_assistant",
|
||||
"conversation_language": "en-US",
|
||||
"language": "en",
|
||||
"name": "pipeline_with_wake_word_2",
|
||||
|
|
|
@ -7,10 +7,12 @@ from homeassistant.components.cloud.assist_pipeline import (
|
|||
)
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
|
||||
"""Test migrate pipeline with invalid platform."""
|
||||
await async_setup_component(hass, "assist_pipeline", {})
|
||||
with pytest.raises(ValueError):
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
hass, Platform.BINARY_SENSOR, "test-engine-id"
|
||||
|
|
|
@ -231,6 +231,7 @@ async def test_login_view_create_pipeline(
|
|||
}
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -270,6 +271,7 @@ async def test_login_view_create_pipeline_fail(
|
|||
}
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@ import pytest
|
|||
from homeassistant.components import conversation
|
||||
from homeassistant.components.shopping_list import intent as sl_intent
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MockAgent
|
||||
|
||||
|
@ -14,7 +16,7 @@ from tests.common import MockConfigEntry
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_support_all(hass):
|
||||
def mock_agent_support_all(hass: HomeAssistant):
|
||||
"""Mock agent that supports all languages."""
|
||||
entry = MockConfigEntry(entry_id="mock-entry-support-all")
|
||||
entry.add_to_hass(hass)
|
||||
|
@ -34,7 +36,7 @@ def mock_shopping_list_io():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def sl_setup(hass):
|
||||
async def sl_setup(hass: HomeAssistant):
|
||||
"""Set up the shopping list."""
|
||||
|
||||
entry = MockConfigEntry(domain="shopping_list")
|
||||
|
@ -43,3 +45,10 @@ async def sl_setup(hass):
|
|||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
|
||||
await sl_intent.async_setup_intents(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(hass: HomeAssistant):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
|
|
@ -101,7 +101,7 @@
|
|||
# ---
|
||||
# name: test_get_agent_info
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
})
|
||||
# ---
|
||||
|
@ -113,7 +113,7 @@
|
|||
# ---
|
||||
# name: test_get_agent_info.2
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
})
|
||||
# ---
|
||||
|
@ -127,7 +127,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
'af',
|
||||
|
@ -207,7 +207,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
]),
|
||||
|
@ -231,7 +231,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
'en',
|
||||
|
@ -255,7 +255,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
'en',
|
||||
|
@ -279,7 +279,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
'de',
|
||||
|
@ -304,7 +304,7 @@
|
|||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'id': 'conversation.home_assistant',
|
||||
'name': 'Home Assistant',
|
||||
'supported_languages': list([
|
||||
'de-CH',
|
||||
|
@ -415,6 +415,36 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_http_processing_intent[conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
dict({
|
||||
'id': 'light.kitchen',
|
||||
'name': 'kitchen',
|
||||
'type': 'entity',
|
||||
}),
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Turned on the light',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_http_processing_intent[homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
|
@ -1035,6 +1065,36 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
dict({
|
||||
'id': 'light.kitchen',
|
||||
'name': 'kitchen',
|
||||
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
|
||||
}),
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Turned on the light',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn kitchen on-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
|
@ -1095,6 +1155,36 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
dict({
|
||||
'id': 'light.kitchen',
|
||||
'name': 'kitchen',
|
||||
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
|
||||
}),
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Turned on the light',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn on kitchen-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
|
@ -1155,6 +1245,36 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
dict({
|
||||
'id': 'light.kitchen',
|
||||
'name': 'kitchen',
|
||||
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
|
||||
}),
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Turned on the light',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
|
@ -1215,6 +1335,36 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
dict({
|
||||
'id': 'light.kitchen',
|
||||
'name': 'kitchen',
|
||||
'type': <IntentResponseTargetType.ENTITY: 'entity'>,
|
||||
}),
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Turned on the light',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
|
|
|
@ -7,7 +7,7 @@ from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult
|
|||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import agent_manager, default_agent
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.homeassistant.exposed_entities import (
|
||||
async_get_assistant_settings,
|
||||
)
|
||||
|
@ -152,9 +152,7 @@ async def test_conversation_agent(
|
|||
init_components,
|
||||
) -> None:
|
||||
"""Test DefaultAgent."""
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
|
||||
conversation.HOME_ASSISTANT_AGENT
|
||||
)
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
with patch(
|
||||
"homeassistant.components.conversation.default_agent.get_languages",
|
||||
return_value=["dwarvish", "elvish", "entish"],
|
||||
|
@ -181,6 +179,7 @@ async def test_expose_flag_automatically_set(
|
|||
|
||||
# After setting up conversation, the expose flag should now be set on all entities
|
||||
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
|
||||
"conversation.home_assistant": {"should_expose": False},
|
||||
light.entity_id: {"should_expose": True},
|
||||
test.entity_id: {"should_expose": False},
|
||||
}
|
||||
|
@ -190,6 +189,7 @@ async def test_expose_flag_automatically_set(
|
|||
hass.states.async_set(new_light, "test")
|
||||
await hass.async_block_till_done()
|
||||
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
|
||||
"conversation.home_assistant": {"should_expose": False},
|
||||
light.entity_id: {"should_expose": True},
|
||||
new_light: {"should_expose": True},
|
||||
test.entity_id: {"should_expose": False},
|
||||
|
@ -254,9 +254,7 @@ async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
|
|||
trigger_sentences = ["It's party time", "It is time to party"]
|
||||
trigger_response = "Cowabunga!"
|
||||
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent(
|
||||
conversation.HOME_ASSISTANT_AGENT
|
||||
)
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
callback = AsyncMock(return_value=trigger_response)
|
||||
|
|
47
tests/components/conversation/test_entity.py
Normal file
47
tests/components/conversation/test_entity.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
"""Tests for conversation entity."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import mock_restore_cache
|
||||
|
||||
|
||||
async def test_state_set_and_restore(hass: HomeAssistant) -> None:
|
||||
"""Test we set and restore state in the integration."""
|
||||
entity_id = "conversation.home_assistant"
|
||||
timestamp = "2023-01-01T23:59:59+00:00"
|
||||
mock_restore_cache(hass, (State(entity_id, timestamp),))
|
||||
|
||||
await async_setup_component(hass, "homeassistant", {})
|
||||
await async_setup_component(hass, "conversation", {})
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == timestamp
|
||||
|
||||
now = dt_util.utcnow()
|
||||
context = Context()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process"
|
||||
) as mock_process,
|
||||
patch("homeassistant.util.dt.utcnow", return_value=now),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{"text": "Hello"},
|
||||
context=context,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(mock_process.mock_calls) == 1
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == now.isoformat()
|
||||
assert state.context is context
|
|
@ -9,7 +9,7 @@ from syrupy.assertion import SnapshotAssertion
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import agent_manager, default_agent
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.models import ConversationInput
|
||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
|
@ -35,7 +35,13 @@ from tests.common import (
|
|||
from tests.components.light.common import MockLight
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
AGENT_ID_OPTIONS = [None, conversation.HOME_ASSISTANT_AGENT]
|
||||
AGENT_ID_OPTIONS = [
|
||||
None,
|
||||
# Old value of conversation.HOME_ASSISTANT_AGENT,
|
||||
"homeassistant",
|
||||
# Current value of conversation.HOME_ASSISTANT_AGENT,
|
||||
"conversation.home_assistant",
|
||||
]
|
||||
|
||||
|
||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||
|
@ -51,14 +57,6 @@ class OrderBeerIntentHandler(intent.IntentHandler):
|
|||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(hass):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||
async def test_http_processing_intent(
|
||||
hass: HomeAssistant,
|
||||
|
@ -752,7 +750,7 @@ async def test_ws_prepare(
|
|||
"""Test the Websocket prepare conversation API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
# No intents should be loaded yet
|
||||
|
@ -854,7 +852,7 @@ async def test_prepare_reload(hass: HomeAssistant) -> None:
|
|||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
# Load intents
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
await agent.async_prepare(language)
|
||||
|
||||
|
@ -882,7 +880,7 @@ async def test_prepare_fail(hass: HomeAssistant) -> None:
|
|||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
# Load intents
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
await agent.async_prepare("not-a-language")
|
||||
|
||||
|
@ -919,7 +917,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
|
|||
hass.states.async_set("cover.front_door", "closed")
|
||||
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
|
||||
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
result = await agent.async_process(
|
||||
|
@ -1063,12 +1061,15 @@ async def test_light_area_same_name(
|
|||
assert call.data == {"entity_id": [kitchen_light.entity_id]}
|
||||
|
||||
|
||||
async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
|
||||
async def test_agent_id_validator_invalid_agent(
|
||||
hass: HomeAssistant, init_components
|
||||
) -> None:
|
||||
"""Test validating agent id."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
conversation.agent_id_validator("invalid_agent")
|
||||
|
||||
conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT)
|
||||
conversation.agent_id_validator("conversation.home_assistant")
|
||||
|
||||
|
||||
async def test_get_agent_list(
|
||||
|
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.conversation import agent_manager, default_agent
|
||||
from homeassistant.components.conversation import default_agent
|
||||
from homeassistant.components.conversation.models import ConversationInput
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import trigger
|
||||
|
@ -515,7 +515,7 @@ async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
|
|||
},
|
||||
)
|
||||
|
||||
agent = await agent_manager.get_agent_manager(hass).async_get_agent()
|
||||
agent = default_agent.async_get_default_agent(hass)
|
||||
assert isinstance(agent, default_agent.DefaultAgent)
|
||||
|
||||
result = await agent.async_process(
|
||||
|
|
|
@ -327,6 +327,7 @@ async def test_conversation_agent(
|
|||
"""Test GoogleAssistantConversationAgent."""
|
||||
await setup_integration()
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
|
@ -334,7 +335,7 @@ async def test_conversation_agent(
|
|||
entry = entries[0]
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
|
||||
|
||||
text1 = "tell me a joke"
|
||||
|
@ -365,6 +366,7 @@ async def test_conversation_agent_refresh_token(
|
|||
"""Test GoogleAssistantConversationAgent when token is expired."""
|
||||
await setup_integration()
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
|
@ -416,6 +418,7 @@ async def test_conversation_agent_language_changed(
|
|||
"""Test GoogleAssistantConversationAgent when language is changed."""
|
||||
await setup_integration()
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
|
|
|
@ -4,6 +4,8 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
@ -23,10 +25,17 @@ def mock_config_entry(hass):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_init_component(hass, mock_config_entry):
|
||||
async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
|
||||
"""Initialize integration."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
with patch("google.generativeai.get_model"):
|
||||
assert await async_setup_component(
|
||||
hass, "google_generative_ai_conversation", {}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_ha(hass: HomeAssistant) -> None:
|
||||
"""Set up Home Assistant."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
|
|
@ -10,6 +10,7 @@ from homeassistant.components import conversation
|
|||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -124,6 +125,7 @@ async def test_template_error(
|
|||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
|
@ -152,7 +154,7 @@ async def test_conversation_agent(
|
|||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test GoogleGenerativeAIAgent."""
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == "*"
|
||||
|
|
|
@ -1033,7 +1033,7 @@ async def test_webhook_handle_conversation_process(
|
|||
webhook_client.server.app.router._frozen = False
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.agent_manager.AgentManager.async_get_agent",
|
||||
"homeassistant.components.conversation.agent_manager.async_get_agent",
|
||||
return_value=mock_conversation_agent,
|
||||
):
|
||||
resp = await webhook_client.post(
|
||||
|
|
|
@ -35,3 +35,9 @@ async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfig
|
|||
):
|
||||
assert await async_setup_component(hass, ollama.DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_ha(hass: HomeAssistant) -> None:
|
||||
"""Set up Home Assistant."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
|
|
@ -229,7 +229,7 @@ async def test_message_history_pruning(
|
|||
assert isinstance(result.conversation_id, str)
|
||||
conversation_ids.append(result.conversation_id)
|
||||
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert isinstance(agent, ollama.OllamaAgent)
|
||||
|
@ -284,7 +284,7 @@ async def test_message_history_unlimited(
|
|||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert isinstance(agent, ollama.OllamaAgent)
|
||||
|
@ -340,7 +340,7 @@ async def test_conversation_agent(
|
|||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test OllamaAgent."""
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == MATCH_ALL
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
@ -30,3 +31,9 @@ async def mock_init_component(hass, mock_config_entry):
|
|||
):
|
||||
assert await async_setup_component(hass, "openai_conversation", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_ha(hass: HomeAssistant) -> None:
|
||||
"""Set up Home Assistant."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
|
|
@ -194,7 +194,7 @@ async def test_conversation_agent(
|
|||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test OpenAIAgent."""
|
||||
agent = await conversation.get_agent_manager(hass).async_get_agent(
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == "*"
|
||||
|
|
Loading…
Add table
Reference in a new issue