Conversation cleanup (#86592)
* Require config entry when setting Conversation agent, add new unset agent method * Remove onboarding from conversation agent * Type attribution * Wrap async_process params in ConversationInput object
This commit is contained in:
parent
5c6656dcac
commit
6c8efe3a3b
8 changed files with 97 additions and 127 deletions
|
@ -22,7 +22,7 @@ from homeassistant.const import (
|
|||
CONF_TYPE,
|
||||
EVENT_HOMEASSISTANT_START,
|
||||
)
|
||||
from homeassistant.core import Context, CoreState, HomeAssistant
|
||||
from homeassistant.core import CoreState, HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import (
|
||||
aiohttp_client,
|
||||
|
@ -147,7 +147,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, almond_hass_start)
|
||||
|
||||
conversation.async_set_agent(hass, agent)
|
||||
conversation.async_set_agent(hass, entry, agent)
|
||||
return True
|
||||
|
||||
|
||||
|
@ -223,7 +223,7 @@ async def _configure_almond_for_ha(
|
|||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Almond."""
|
||||
conversation.async_set_agent(hass, None)
|
||||
conversation.async_unset_agent(hass, entry)
|
||||
return True
|
||||
|
||||
|
||||
|
@ -264,40 +264,13 @@ class AlmondAgent(conversation.AbstractConversationAgent):
|
|||
"""Return the attribution."""
|
||||
return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"}
|
||||
|
||||
async def async_get_onboarding(self):
|
||||
"""Get onboard url if not onboarded."""
|
||||
if self.entry.data.get("onboarded"):
|
||||
return None
|
||||
|
||||
host = self.entry.data["host"]
|
||||
if self.entry.data.get("is_hassio"):
|
||||
host = "/core_almond"
|
||||
return {
|
||||
"text": (
|
||||
"Would you like to opt-in to share your anonymized commands with"
|
||||
" Stanford to improve Almond's responses?"
|
||||
),
|
||||
"url": f"{host}/conversation",
|
||||
}
|
||||
|
||||
async def async_set_onboarding(self, shown):
|
||||
"""Set onboarding status."""
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.entry, data={**self.entry.data, "onboarded": shown}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def async_process(
|
||||
self,
|
||||
text: str,
|
||||
context: Context,
|
||||
conversation_id: str | None = None,
|
||||
language: str | None = None,
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
response = await self.api.async_converse_text(text, conversation_id)
|
||||
language = language or self.hass.config.language
|
||||
response = await self.api.async_converse_text(
|
||||
user_input.text, user_input.conversation_id
|
||||
)
|
||||
|
||||
first_choice = True
|
||||
buffer = ""
|
||||
|
@ -318,8 +291,8 @@ class AlmondAgent(conversation.AbstractConversationAgent):
|
|||
buffer += ","
|
||||
buffer += f" {message['title']}"
|
||||
|
||||
intent_response = intent.IntentResponse(language=language)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(buffer.strip())
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
|
|
@ -10,12 +10,13 @@ import voluptuous as vol
|
|||
from homeassistant import core
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv, intent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationResult
|
||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .default_agent import DefaultAgent
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -62,11 +63,25 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
|
||||
@core.callback
|
||||
@bind_hass
|
||||
def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None):
|
||||
def async_set_agent(
|
||||
hass: core.HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
agent: AbstractConversationAgent,
|
||||
):
|
||||
"""Set the agent to handle the conversations."""
|
||||
hass.data[DATA_AGENT] = agent
|
||||
|
||||
|
||||
@core.callback
|
||||
@bind_hass
|
||||
def async_unset_agent(
|
||||
hass: core.HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
):
|
||||
"""Set the agent to handle the conversations."""
|
||||
hass.data[DATA_AGENT] = None
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
if config_intents := config.get(DOMAIN, {}).get("intents"):
|
||||
|
@ -79,7 +94,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
agent = await _get_agent(hass)
|
||||
try:
|
||||
await agent.async_process(
|
||||
text, service.context, language=service.data.get(ATTR_LANGUAGE)
|
||||
ConversationInput(
|
||||
text=text,
|
||||
context=service.context,
|
||||
conversation_id=None,
|
||||
language=service.data.get(ATTR_LANGUAGE, hass.config.language),
|
||||
)
|
||||
)
|
||||
except intent.IntentHandleError as err:
|
||||
_LOGGER.error("Error processing %s: %s", text, err)
|
||||
|
@ -99,7 +119,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
websocket_api.async_register_command(hass, websocket_process)
|
||||
websocket_api.async_register_command(hass, websocket_prepare)
|
||||
websocket_api.async_register_command(hass, websocket_get_agent_info)
|
||||
websocket_api.async_register_command(hass, websocket_set_onboarding)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -158,41 +177,17 @@ async def websocket_get_agent_info(
|
|||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Do we need onboarding."""
|
||||
"""Info about the agent in use."""
|
||||
agent = await _get_agent(hass)
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"onboarding": await agent.async_get_onboarding(),
|
||||
"attribution": agent.attribution,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "conversation/onboarding/set",
|
||||
vol.Required("shown"): bool,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_set_onboarding(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set onboarding status."""
|
||||
agent = await _get_agent(hass)
|
||||
|
||||
success = await agent.async_set_onboarding(msg["shown"])
|
||||
|
||||
if success:
|
||||
connection.send_result(msg["id"])
|
||||
else:
|
||||
connection.send_error(msg["id"], "error", "Failed to set onboarding")
|
||||
|
||||
|
||||
class ConversationProcessView(http.HomeAssistantView):
|
||||
"""View to process text."""
|
||||
|
||||
|
@ -242,5 +237,12 @@ async def async_converse(
|
|||
if language is None:
|
||||
language = hass.config.language
|
||||
|
||||
result = await agent.async_process(text, context, conversation_id, language)
|
||||
result = await agent.async_process(
|
||||
ConversationInput(
|
||||
text=text,
|
||||
context=context,
|
||||
conversation_id=conversation_id,
|
||||
language=language,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -3,12 +3,22 @@ from __future__ import annotations
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import intent
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationInput:
|
||||
"""User input to be processed."""
|
||||
|
||||
text: str
|
||||
context: Context
|
||||
conversation_id: str | None
|
||||
language: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationResult:
|
||||
"""Result of async_process."""
|
||||
|
@ -24,34 +34,27 @@ class ConversationResult:
|
|||
}
|
||||
|
||||
|
||||
class Attribution(TypedDict):
|
||||
"""Attribution for a conversation agent."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
|
||||
|
||||
class AbstractConversationAgent(ABC):
|
||||
"""Abstract conversation agent."""
|
||||
|
||||
@property
|
||||
def attribution(self):
|
||||
def attribution(self) -> Attribution | None:
|
||||
"""Return the attribution."""
|
||||
return None
|
||||
|
||||
async def async_get_onboarding(self):
|
||||
"""Get onboard data."""
|
||||
return None
|
||||
|
||||
async def async_set_onboarding(self, shown):
|
||||
"""Set onboard data."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
async def async_process(
|
||||
self,
|
||||
text: str,
|
||||
context: Context,
|
||||
conversation_id: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> ConversationResult:
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
|
||||
async def async_reload(self, language: str | None = None):
|
||||
async def async_reload(self, language: str | None = None) -> None:
|
||||
"""Clear cached intents for a language."""
|
||||
|
||||
async def async_prepare(self, language: str | None = None):
|
||||
async def async_prepare(self, language: str | None = None) -> None:
|
||||
"""Load intents for a language."""
|
||||
|
|
|
@ -20,7 +20,7 @@ from homeassistant import core, setup
|
|||
from homeassistant.helpers import area_registry, entity_registry, intent, template
|
||||
from homeassistant.helpers.json import json_loads
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationResult
|
||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -81,16 +81,11 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
if config_intents:
|
||||
self._config_intents = config_intents
|
||||
|
||||
async def async_process(
|
||||
self,
|
||||
text: str,
|
||||
context: core.Context,
|
||||
conversation_id: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> ConversationResult:
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
language = language or self.hass.config.language
|
||||
language = user_input.language or self.hass.config.language
|
||||
lang_intents = self._lang_intents.get(language)
|
||||
conversation_id = None # Not supported
|
||||
|
||||
# Reload intents if missing or new components
|
||||
if lang_intents is None or (
|
||||
|
@ -114,9 +109,9 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
"name": self._make_names_list(),
|
||||
}
|
||||
|
||||
result = recognize(text, lang_intents.intents, slot_lists=slot_lists)
|
||||
result = recognize(user_input.text, lang_intents.intents, slot_lists=slot_lists)
|
||||
if result is None:
|
||||
_LOGGER.debug("No intent was matched for '%s'", text)
|
||||
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
|
||||
return _make_error_result(
|
||||
language,
|
||||
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
||||
|
@ -133,8 +128,8 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
entity.name: {"value": entity.value}
|
||||
for entity in result.entities_list
|
||||
},
|
||||
text,
|
||||
context,
|
||||
user_input.text,
|
||||
user_input.context,
|
||||
language,
|
||||
)
|
||||
except intent.IntentHandleError:
|
||||
|
|
|
@ -9,7 +9,7 @@ import voluptuous as vol
|
|||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
|
||||
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import config_validation as cv, discovery, intent
|
||||
from homeassistant.helpers.config_entry_oauth2_flow import (
|
||||
|
@ -124,9 +124,9 @@ async def update_listener(hass, entry):
|
|||
"""Handle options update."""
|
||||
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
|
||||
agent = GoogleAssistantConversationAgent(hass, entry)
|
||||
conversation.async_set_agent(hass, agent)
|
||||
conversation.async_set_agent(hass, entry, agent)
|
||||
else:
|
||||
conversation.async_set_agent(hass, None)
|
||||
conversation.async_unset_agent(hass, entry)
|
||||
|
||||
|
||||
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||
|
@ -148,11 +148,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||
}
|
||||
|
||||
async def async_process(
|
||||
self,
|
||||
text: str,
|
||||
context: Context,
|
||||
conversation_id: str | None = None,
|
||||
language: str | None = None,
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
if self.session:
|
||||
|
@ -170,12 +166,11 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||
)
|
||||
self.assistant = TextAssistant(credentials, language_code)
|
||||
|
||||
resp = self.assistant.assist(text)
|
||||
resp = self.assistant.assist(user_input.text)
|
||||
text_response = resp[0]
|
||||
|
||||
language = language or self.hass.config.language
|
||||
intent_response = intent.IntentResponse(language=language)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(text_response)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import intent
|
||||
|
||||
|
||||
|
@ -15,16 +14,12 @@ class MockAgent(conversation.AbstractConversationAgent):
|
|||
self.response = "Test response"
|
||||
|
||||
async def async_process(
|
||||
self,
|
||||
text: str,
|
||||
context: Context,
|
||||
conversation_id: str | None = None,
|
||||
language: str | None = None,
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process some text."""
|
||||
self.calls.append((text, context, conversation_id, language))
|
||||
response = intent.IntentResponse(language=language)
|
||||
self.calls.append(user_input)
|
||||
response = intent.IntentResponse(language=user_input.language)
|
||||
response.async_set_speech(self.response)
|
||||
return conversation.ConversationResult(
|
||||
response=response, conversation_id=conversation_id
|
||||
response=response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
|
|
@ -11,5 +11,5 @@ from . import MockAgent
|
|||
def mock_agent(hass):
|
||||
"""Mock agent."""
|
||||
agent = MockAgent()
|
||||
conversation.async_set_agent(hass, agent)
|
||||
conversation.async_set_agent(hass, None, agent)
|
||||
return agent
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""The tests for the Conversation component."""
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -295,10 +295,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
|
|||
}
|
||||
|
||||
assert len(mock_agent.calls) == 1
|
||||
assert mock_agent.calls[0][0] == "Test Text"
|
||||
assert mock_agent.calls[0][1].user_id == hass_admin_user.id
|
||||
assert mock_agent.calls[0][2] == "test-conv-id"
|
||||
assert mock_agent.calls[0][3] == "test-language"
|
||||
assert mock_agent.calls[0].text == "Test Text"
|
||||
assert mock_agent.calls[0].context.user_id == hass_admin_user.id
|
||||
assert mock_agent.calls[0].conversation_id == "test-conv-id"
|
||||
assert mock_agent.calls[0].language == "test-language"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -349,7 +349,7 @@ async def test_ws_api(hass, hass_ws_client, payload):
|
|||
"language": payload.get("language", hass.config.language),
|
||||
"data": {"code": "no_intent_match"},
|
||||
},
|
||||
"conversation_id": payload.get("conversation_id") or ANY,
|
||||
"conversation_id": None,
|
||||
}
|
||||
|
||||
|
||||
|
@ -560,5 +560,12 @@ async def test_non_default_response(hass, init_components):
|
|||
agent = await conversation._get_agent(hass)
|
||||
assert isinstance(agent, conversation.DefaultAgent)
|
||||
|
||||
result = await agent.async_process("open the front door", Context())
|
||||
result = await agent.async_process(
|
||||
conversation.ConversationInput(
|
||||
text="open the front door",
|
||||
context=Context(),
|
||||
conversation_id=None,
|
||||
language=hass.config.language,
|
||||
)
|
||||
)
|
||||
assert result.response.speech["plain"]["speech"] == "Opened front door"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue