Conversation cleanup (#86592)

* Require config entry when setting Conversation agent, add new unset agent method

* Remove onboarding from conversation agent

* Type attribution

* Wrap async_process params in ConversationInput object
This commit is contained in:
Paulus Schoutsen 2023-01-24 22:47:49 -05:00 committed by GitHub
parent 5c6656dcac
commit 6c8efe3a3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 97 additions and 127 deletions

View file

@ -22,7 +22,7 @@ from homeassistant.const import (
CONF_TYPE, CONF_TYPE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_START,
) )
from homeassistant.core import Context, CoreState, HomeAssistant from homeassistant.core import CoreState, HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import ( from homeassistant.helpers import (
aiohttp_client, aiohttp_client,
@ -147,7 +147,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, almond_hass_start) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, almond_hass_start)
conversation.async_set_agent(hass, agent) conversation.async_set_agent(hass, entry, agent)
return True return True
@ -223,7 +223,7 @@ async def _configure_almond_for_ha(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Almond.""" """Unload Almond."""
conversation.async_set_agent(hass, None) conversation.async_unset_agent(hass, entry)
return True return True
@ -264,40 +264,13 @@ class AlmondAgent(conversation.AbstractConversationAgent):
"""Return the attribution.""" """Return the attribution."""
return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"} return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"}
async def async_get_onboarding(self):
"""Get onboard url if not onboarded."""
if self.entry.data.get("onboarded"):
return None
host = self.entry.data["host"]
if self.entry.data.get("is_hassio"):
host = "/core_almond"
return {
"text": (
"Would you like to opt-in to share your anonymized commands with"
" Stanford to improve Almond's responses?"
),
"url": f"{host}/conversation",
}
async def async_set_onboarding(self, shown):
"""Set onboarding status."""
self.hass.config_entries.async_update_entry(
self.entry, data={**self.entry.data, "onboarded": shown}
)
return True
async def async_process( async def async_process(
self, self, user_input: conversation.ConversationInput
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
response = await self.api.async_converse_text(text, conversation_id) response = await self.api.async_converse_text(
language = language or self.hass.config.language user_input.text, user_input.conversation_id
)
first_choice = True first_choice = True
buffer = "" buffer = ""
@ -318,8 +291,8 @@ class AlmondAgent(conversation.AbstractConversationAgent):
buffer += "," buffer += ","
buffer += f" {message['title']}" buffer += f" {message['title']}"
intent_response = intent.IntentResponse(language=language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(buffer.strip()) intent_response.async_set_speech(buffer.strip())
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )

View file

@ -10,12 +10,13 @@ import voluptuous as vol
from homeassistant import core from homeassistant import core
from homeassistant.components import http, websocket_api from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from .agent import AbstractConversationAgent, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .default_agent import DefaultAgent from .default_agent import DefaultAgent
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -62,11 +63,25 @@ CONFIG_SCHEMA = vol.Schema(
@core.callback @core.callback
@bind_hass @bind_hass
def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None): def async_set_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
agent: AbstractConversationAgent,
):
"""Set the agent to handle the conversations.""" """Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = agent hass.data[DATA_AGENT] = agent
@core.callback
@bind_hass
def async_unset_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
):
"""Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = None
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
if config_intents := config.get(DOMAIN, {}).get("intents"): if config_intents := config.get(DOMAIN, {}).get("intents"):
@ -79,7 +94,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
agent = await _get_agent(hass) agent = await _get_agent(hass)
try: try:
await agent.async_process( await agent.async_process(
text, service.context, language=service.data.get(ATTR_LANGUAGE) ConversationInput(
text=text,
context=service.context,
conversation_id=None,
language=service.data.get(ATTR_LANGUAGE, hass.config.language),
)
) )
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err) _LOGGER.error("Error processing %s: %s", text, err)
@ -99,7 +119,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_process) websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_get_agent_info) websocket_api.async_register_command(hass, websocket_get_agent_info)
websocket_api.async_register_command(hass, websocket_set_onboarding)
return True return True
@ -158,41 +177,17 @@ async def websocket_get_agent_info(
connection: websocket_api.ActiveConnection, connection: websocket_api.ActiveConnection,
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Do we need onboarding.""" """Info about the agent in use."""
agent = await _get_agent(hass) agent = await _get_agent(hass)
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{ {
"onboarding": await agent.async_get_onboarding(),
"attribution": agent.attribution, "attribution": agent.attribution,
}, },
) )
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/onboarding/set",
vol.Required("shown"): bool,
}
)
@websocket_api.async_response
async def websocket_set_onboarding(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set onboarding status."""
agent = await _get_agent(hass)
success = await agent.async_set_onboarding(msg["shown"])
if success:
connection.send_result(msg["id"])
else:
connection.send_error(msg["id"], "error", "Failed to set onboarding")
class ConversationProcessView(http.HomeAssistantView): class ConversationProcessView(http.HomeAssistantView):
"""View to process text.""" """View to process text."""
@ -242,5 +237,12 @@ async def async_converse(
if language is None: if language is None:
language = hass.config.language language = hass.config.language
result = await agent.async_process(text, context, conversation_id, language) result = await agent.async_process(
ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
language=language,
)
)
return result return result

View file

@ -3,12 +3,22 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, TypedDict
from homeassistant.core import Context from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent
@dataclass
class ConversationInput:
"""User input to be processed."""
text: str
context: Context
conversation_id: str | None
language: str
@dataclass @dataclass
class ConversationResult: class ConversationResult:
"""Result of async_process.""" """Result of async_process."""
@ -24,34 +34,27 @@ class ConversationResult:
} }
class Attribution(TypedDict):
"""Attribution for a conversation agent."""
name: str
url: str
class AbstractConversationAgent(ABC): class AbstractConversationAgent(ABC):
"""Abstract conversation agent.""" """Abstract conversation agent."""
@property @property
def attribution(self): def attribution(self) -> Attribution | None:
"""Return the attribution.""" """Return the attribution."""
return None return None
async def async_get_onboarding(self):
"""Get onboard data."""
return None
async def async_set_onboarding(self, shown):
"""Set onboard data."""
return True
@abstractmethod @abstractmethod
async def async_process( async def async_process(self, user_input: ConversationInput) -> ConversationResult:
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
async def async_reload(self, language: str | None = None): async def async_reload(self, language: str | None = None) -> None:
"""Clear cached intents for a language.""" """Clear cached intents for a language."""
async def async_prepare(self, language: str | None = None): async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language.""" """Load intents for a language."""

View file

@ -20,7 +20,7 @@ from homeassistant import core, setup
from homeassistant.helpers import area_registry, entity_registry, intent, template from homeassistant.helpers import area_registry, entity_registry, intent, template
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads
from .agent import AbstractConversationAgent, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN from .const import DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -81,16 +81,11 @@ class DefaultAgent(AbstractConversationAgent):
if config_intents: if config_intents:
self._config_intents = config_intents self._config_intents = config_intents
async def async_process( async def async_process(self, user_input: ConversationInput) -> ConversationResult:
self,
text: str,
context: core.Context,
conversation_id: str | None = None,
language: str | None = None,
) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
language = language or self.hass.config.language language = user_input.language or self.hass.config.language
lang_intents = self._lang_intents.get(language) lang_intents = self._lang_intents.get(language)
conversation_id = None # Not supported
# Reload intents if missing or new components # Reload intents if missing or new components
if lang_intents is None or ( if lang_intents is None or (
@ -114,9 +109,9 @@ class DefaultAgent(AbstractConversationAgent):
"name": self._make_names_list(), "name": self._make_names_list(),
} }
result = recognize(text, lang_intents.intents, slot_lists=slot_lists) result = recognize(user_input.text, lang_intents.intents, slot_lists=slot_lists)
if result is None: if result is None:
_LOGGER.debug("No intent was matched for '%s'", text) _LOGGER.debug("No intent was matched for '%s'", user_input.text)
return _make_error_result( return _make_error_result(
language, language,
intent.IntentResponseErrorCode.NO_INTENT_MATCH, intent.IntentResponseErrorCode.NO_INTENT_MATCH,
@ -133,8 +128,8 @@ class DefaultAgent(AbstractConversationAgent):
entity.name: {"value": entity.value} entity.name: {"value": entity.value}
for entity in result.entities_list for entity in result.entities_list
}, },
text, user_input.text,
context, user_input.context,
language, language,
) )
except intent.IntentHandleError: except intent.IntentHandleError:

View file

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv, discovery, intent from homeassistant.helpers import config_validation as cv, discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import ( from homeassistant.helpers.config_entry_oauth2_flow import (
@ -124,9 +124,9 @@ async def update_listener(hass, entry):
"""Handle options update.""" """Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False): if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry) agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent) conversation.async_set_agent(hass, entry, agent)
else: else:
conversation.async_set_agent(hass, None) conversation.async_unset_agent(hass, entry)
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
@ -148,11 +148,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
} }
async def async_process( async def async_process(
self, self, user_input: conversation.ConversationInput
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
if self.session: if self.session:
@ -170,12 +166,11 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
) )
self.assistant = TextAssistant(credentials, language_code) self.assistant = TextAssistant(credentials, language_code)
resp = self.assistant.assist(text) resp = self.assistant.assist(user_input.text)
text_response = resp[0] text_response = resp[0]
language = language or self.hass.config.language intent_response = intent.IntentResponse(language=user_input.language)
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_speech(text_response) intent_response.async_set_speech(text_response)
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent
@ -15,16 +14,12 @@ class MockAgent(conversation.AbstractConversationAgent):
self.response = "Test response" self.response = "Test response"
async def async_process( async def async_process(
self, self, user_input: conversation.ConversationInput
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process some text.""" """Process some text."""
self.calls.append((text, context, conversation_id, language)) self.calls.append(user_input)
response = intent.IntentResponse(language=language) response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(self.response) response.async_set_speech(self.response)
return conversation.ConversationResult( return conversation.ConversationResult(
response=response, conversation_id=conversation_id response=response, conversation_id=user_input.conversation_id
) )

View file

@ -11,5 +11,5 @@ from . import MockAgent
def mock_agent(hass): def mock_agent(hass):
"""Mock agent.""" """Mock agent."""
agent = MockAgent() agent = MockAgent()
conversation.async_set_agent(hass, agent) conversation.async_set_agent(hass, None, agent)
return agent return agent

View file

@ -1,6 +1,6 @@
"""The tests for the Conversation component.""" """The tests for the Conversation component."""
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import ANY, patch from unittest.mock import patch
import pytest import pytest
@ -295,10 +295,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
} }
assert len(mock_agent.calls) == 1 assert len(mock_agent.calls) == 1
assert mock_agent.calls[0][0] == "Test Text" assert mock_agent.calls[0].text == "Test Text"
assert mock_agent.calls[0][1].user_id == hass_admin_user.id assert mock_agent.calls[0].context.user_id == hass_admin_user.id
assert mock_agent.calls[0][2] == "test-conv-id" assert mock_agent.calls[0].conversation_id == "test-conv-id"
assert mock_agent.calls[0][3] == "test-language" assert mock_agent.calls[0].language == "test-language"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -349,7 +349,7 @@ async def test_ws_api(hass, hass_ws_client, payload):
"language": payload.get("language", hass.config.language), "language": payload.get("language", hass.config.language),
"data": {"code": "no_intent_match"}, "data": {"code": "no_intent_match"},
}, },
"conversation_id": payload.get("conversation_id") or ANY, "conversation_id": None,
} }
@ -560,5 +560,12 @@ async def test_non_default_response(hass, init_components):
agent = await conversation._get_agent(hass) agent = await conversation._get_agent(hass)
assert isinstance(agent, conversation.DefaultAgent) assert isinstance(agent, conversation.DefaultAgent)
result = await agent.async_process("open the front door", Context()) result = await agent.async_process(
conversation.ConversationInput(
text="open the front door",
context=Context(),
conversation_id=None,
language=hass.config.language,
)
)
assert result.response.speech["plain"]["speech"] == "Opened front door" assert result.response.speech["plain"]["speech"] == "Opened front door"