Pipelines to default to Home Assistant agent (#91321)
* Pipelines to default to Home Assistant agent * Tests fix
This commit is contained in:
parent
0678ab4e45
commit
c9d81bd217
4 changed files with 14 additions and 10 deletions
|
@ -18,10 +18,12 @@ from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
|
from .const import HOME_ASSISTANT_AGENT
|
||||||
from .default_agent import DefaultAgent
|
from .default_agent import DefaultAgent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
"HOME_ASSISTANT_AGENT",
|
||||||
"async_converse",
|
"async_converse",
|
||||||
"async_get_agent_info",
|
"async_get_agent_info",
|
||||||
"async_set_agent",
|
"async_set_agent",
|
||||||
|
@ -333,8 +335,6 @@ async def async_converse(
|
||||||
class AgentManager:
|
class AgentManager:
|
||||||
"""Class to manage conversation agents."""
|
"""Class to manage conversation agents."""
|
||||||
|
|
||||||
HOME_ASSISTANT_AGENT = "homeassistant"
|
|
||||||
|
|
||||||
default_agent: str = HOME_ASSISTANT_AGENT
|
default_agent: str = HOME_ASSISTANT_AGENT
|
||||||
_builtin_agent: AbstractConversationAgent | None = None
|
_builtin_agent: AbstractConversationAgent | None = None
|
||||||
|
|
||||||
|
@ -351,7 +351,7 @@ class AgentManager:
|
||||||
if agent_id is None:
|
if agent_id is None:
|
||||||
agent_id = self.default_agent
|
agent_id = self.default_agent
|
||||||
|
|
||||||
if agent_id == AgentManager.HOME_ASSISTANT_AGENT:
|
if agent_id == HOME_ASSISTANT_AGENT:
|
||||||
if self._builtin_agent is not None:
|
if self._builtin_agent is not None:
|
||||||
return self._builtin_agent
|
return self._builtin_agent
|
||||||
|
|
||||||
|
@ -376,7 +376,7 @@ class AgentManager:
|
||||||
"""List all agents."""
|
"""List all agents."""
|
||||||
agents: list[AgentInfo] = [
|
agents: list[AgentInfo] = [
|
||||||
{
|
{
|
||||||
"id": AgentManager.HOME_ASSISTANT_AGENT,
|
"id": HOME_ASSISTANT_AGENT,
|
||||||
"name": "Home Assistant",
|
"name": "Home Assistant",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -401,18 +401,18 @@ class AgentManager:
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_is_valid_agent_id(self, agent_id: str) -> bool:
|
def async_is_valid_agent_id(self, agent_id: str) -> bool:
|
||||||
"""Check if the agent id is valid."""
|
"""Check if the agent id is valid."""
|
||||||
return agent_id in self._agents or agent_id == AgentManager.HOME_ASSISTANT_AGENT
|
return agent_id in self._agents or agent_id == HOME_ASSISTANT_AGENT
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
||||||
"""Set the agent."""
|
"""Set the agent."""
|
||||||
self._agents[agent_id] = agent
|
self._agents[agent_id] = agent
|
||||||
if self.default_agent == AgentManager.HOME_ASSISTANT_AGENT:
|
if self.default_agent == HOME_ASSISTANT_AGENT:
|
||||||
self.default_agent = agent_id
|
self.default_agent = agent_id
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_unset_agent(self, agent_id: str) -> None:
|
def async_unset_agent(self, agent_id: str) -> None:
|
||||||
"""Unset the agent."""
|
"""Unset the agent."""
|
||||||
if self.default_agent == agent_id:
|
if self.default_agent == agent_id:
|
||||||
self.default_agent = AgentManager.HOME_ASSISTANT_AGENT
|
self.default_agent = HOME_ASSISTANT_AGENT
|
||||||
self._agents.pop(agent_id, None)
|
self._agents.pop(agent_id, None)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
"""Const for conversation integration."""
|
"""Const for conversation integration."""
|
||||||
|
|
||||||
DOMAIN = "conversation"
|
DOMAIN = "conversation"
|
||||||
|
HOME_ASSISTANT_AGENT = "homeassistant"
|
||||||
|
|
|
@ -289,7 +289,10 @@ class PipelineRun:
|
||||||
async def prepare_recognize_intent(self) -> None:
|
async def prepare_recognize_intent(self) -> None:
|
||||||
"""Prepare recognizing an intent."""
|
"""Prepare recognizing an intent."""
|
||||||
agent_info = conversation.async_get_agent_info(
|
agent_info = conversation.async_get_agent_info(
|
||||||
self.hass, self.pipeline.conversation_engine
|
self.hass,
|
||||||
|
# If no conversation engine is set, use the Home Assistant agent
|
||||||
|
# (the conversation integration default is currently the last one set)
|
||||||
|
self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent_info is None:
|
if agent_info is None:
|
||||||
|
|
|
@ -25,7 +25,7 @@ from . import expose_entity, expose_new
|
||||||
from tests.common import MockConfigEntry, MockUser, async_mock_service
|
from tests.common import MockConfigEntry, MockUser, async_mock_service
|
||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
AGENT_ID_OPTIONS = [None, conversation.AgentManager.HOME_ASSISTANT_AGENT]
|
AGENT_ID_OPTIONS = [None, conversation.HOME_ASSISTANT_AGENT]
|
||||||
|
|
||||||
|
|
||||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
|
@ -1569,7 +1569,7 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
conversation.agent_id_validator("invalid_agent")
|
conversation.agent_id_validator("invalid_agent")
|
||||||
|
|
||||||
conversation.agent_id_validator(conversation.AgentManager.HOME_ASSISTANT_AGENT)
|
conversation.agent_id_validator(conversation.HOME_ASSISTANT_AGENT)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_agent_list(
|
async def test_get_agent_list(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue