Make conversation and intent context aware (#28965)

* WIP

* LINT
This commit is contained in:
Paulus Schoutsen 2019-11-26 02:30:21 -08:00 committed by Pascal Vizeli
parent 6c6a5c50a5
commit c76f768a82
9 changed files with 82 additions and 57 deletions

View file

@ -10,7 +10,7 @@ from aiohttp import ClientSession, ClientError
from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant, CoreState from homeassistant.core import HomeAssistant, CoreState, Context
from homeassistant.const import CONF_TYPE, CONF_HOST, EVENT_HOMEASSISTANT_START from homeassistant.const import CONF_TYPE, CONF_HOST, EVENT_HOMEASSISTANT_START
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.auth.const import GROUP_ID_ADMIN
@ -277,7 +277,7 @@ class AlmondAgent(conversation.AbstractConversationAgent):
return True return True
async def async_process( async def async_process(
self, text: str, conversation_id: Optional[str] = None self, text: str, context: Context, conversation_id: Optional[str] = None
) -> intent.IntentResponse: ) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""
response = await self.api.async_converse_text(text, conversation_id) response = await self.api.async_converse_text(text, conversation_id)

View file

@ -50,15 +50,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent):
hass.data[DATA_AGENT] = agent hass.data[DATA_AGENT] = agent
async def get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get agent."""
agent = hass.data.get(DATA_AGENT)
if agent is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize(hass.data.get(DATA_CONFIG))
return agent
async def async_setup(hass, config): async def async_setup(hass, config):
"""Register the process service.""" """Register the process service."""
hass.data[DATA_CONFIG] = config hass.data[DATA_CONFIG] = config
@ -67,8 +58,9 @@ async def async_setup(hass, config):
"""Parse text into commands.""" """Parse text into commands."""
text = service.data[ATTR_TEXT] text = service.data[ATTR_TEXT]
_LOGGER.debug("Processing: <%s>", text) _LOGGER.debug("Processing: <%s>", text)
agent = await _get_agent(hass)
try: try:
await process(hass, text, service.context.id) await agent.async_process(text, service.context)
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err) _LOGGER.error("Error processing %s: %s", text, err)
@ -84,27 +76,6 @@ async def async_setup(hass, config):
return True return True
async def process(hass: core.HomeAssistant, text: str, conversation_id: str):
"""Process text and get intent."""
agent = await get_agent(hass)
return await agent.async_process(text, conversation_id)
async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str):
"""Process text and get intent."""
try:
intent_result = await process(hass, text, conversation_id)
except intent.IntentHandleError as err:
intent_result = intent.IntentResponse()
intent_result.async_set_speech(str(err))
if intent_result is None:
intent_result = intent.IntentResponse()
intent_result.async_set_speech("Sorry, I didn't understand that")
return intent_result
@websocket_api.async_response @websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{"type": "conversation/process", "text": str, vol.Optional("conversation_id"): str} {"type": "conversation/process", "text": str, vol.Optional("conversation_id"): str}
@ -112,7 +83,10 @@ async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str):
async def websocket_process(hass, connection, msg): async def websocket_process(hass, connection, msg):
"""Process text.""" """Process text."""
connection.send_result( connection.send_result(
msg["id"], await get_intent(hass, msg["text"], msg.get("conversation_id")) msg["id"],
await _async_converse(
hass, msg["text"], msg.get("conversation_id"), connection.context(msg)
),
) )
@ -120,7 +94,7 @@ async def websocket_process(hass, connection, msg):
@websocket_api.websocket_command({"type": "conversation/agent/info"}) @websocket_api.websocket_command({"type": "conversation/agent/info"})
async def websocket_get_agent_info(hass, connection, msg): async def websocket_get_agent_info(hass, connection, msg):
"""Do we need onboarding.""" """Do we need onboarding."""
agent = await get_agent(hass) agent = await _get_agent(hass)
connection.send_result( connection.send_result(
msg["id"], msg["id"],
@ -135,7 +109,7 @@ async def websocket_get_agent_info(hass, connection, msg):
@websocket_api.websocket_command({"type": "conversation/onboarding/set", "shown": bool}) @websocket_api.websocket_command({"type": "conversation/onboarding/set", "shown": bool})
async def websocket_set_onboarding(hass, connection, msg): async def websocket_set_onboarding(hass, connection, msg):
"""Set onboarding status.""" """Set onboarding status."""
agent = await get_agent(hass) agent = await _get_agent(hass)
success = await agent.async_set_onboarding(msg["shown"]) success = await agent.async_set_onboarding(msg["shown"])
@ -157,8 +131,9 @@ class ConversationProcessView(http.HomeAssistantView):
async def post(self, request, data): async def post(self, request, data):
"""Send a request for processing.""" """Send a request for processing."""
hass = request.app["hass"] hass = request.app["hass"]
intent_result = await get_intent(
hass, data["text"], data.get("conversation_id") intent_result = await _async_converse(
hass, data["text"], data.get("conversation_id"), self.context(request)
) )
return self.json(intent_result) return self.json(intent_result)
@ -188,7 +163,7 @@ class ConversationHandleView(http.HomeAssistantView):
key: {"value": value} for key, value in data.get("data", {}).items() key: {"value": value} for key, value in data.get("data", {}).items()
} }
intent_result = await intent.async_handle( intent_result = await intent.async_handle(
hass, DOMAIN, intent_name, slots, "" hass, DOMAIN, intent_name, slots, "", self.context(request)
) )
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
intent_result = intent.IntentResponse() intent_result = intent.IntentResponse()
@ -199,3 +174,30 @@ class ConversationHandleView(http.HomeAssistantView):
intent_result.async_set_speech("Sorry, I couldn't handle that") intent_result.async_set_speech("Sorry, I couldn't handle that")
return self.json(intent_result) return self.json(intent_result)
async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get the active conversation agent."""
agent = hass.data.get(DATA_AGENT)
if agent is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize(hass.data.get(DATA_CONFIG))
return agent
async def _async_converse(
hass: core.HomeAssistant, text: str, conversation_id: str, context: core.Context
) -> intent.IntentResponse:
"""Process text and get intent."""
agent = await _get_agent(hass)
try:
intent_result = await agent.async_process(text, context, conversation_id)
except intent.IntentHandleError as err:
intent_result = intent.IntentResponse()
intent_result.async_set_speech(str(err))
if intent_result is None:
intent_result = intent.IntentResponse()
intent_result.async_set_speech("Sorry, I didn't understand that")
return intent_result

View file

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent
@ -23,6 +24,6 @@ class AbstractConversationAgent(ABC):
@abstractmethod @abstractmethod
async def async_process( async def async_process(
self, text: str, conversation_id: Optional[str] = None self, text: str, context: Context, conversation_id: Optional[str] = None
) -> intent.IntentResponse: ) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""

View file

@ -109,7 +109,7 @@ class DefaultAgent(AbstractConversationAgent):
async_register(self.hass, intent_type, sentences) async_register(self.hass, intent_type, sentences)
async def async_process( async def async_process(
self, text: str, conversation_id: Optional[str] = None self, text: str, context: core.Context, conversation_id: Optional[str] = None
) -> intent.IntentResponse: ) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""
intents = self.hass.data[DOMAIN] intents = self.hass.data[DOMAIN]
@ -127,4 +127,5 @@ class DefaultAgent(AbstractConversationAgent):
intent_type, intent_type,
{key: {"value": value} for key, value in match.groupdict().items()}, {key: {"value": value} for key, value in match.groupdict().items()},
text, text,
context,
) )

View file

@ -34,8 +34,8 @@ class HomeAssistantView:
requires_auth = True requires_auth = True
cors_allowed = False cors_allowed = False
# pylint: disable=no-self-use @staticmethod
def context(self, request): def context(request):
"""Generate a context from a request.""" """Generate a context from a request."""
user = request.get("hass_user") user = request.get("hass_user")
if user is None: if user is None:
@ -43,7 +43,8 @@ class HomeAssistantView:
return Context(user_id=user.id) return Context(user_id=user.id)
def json(self, result, status_code=200, headers=None): @staticmethod
def json(result, status_code=200, headers=None):
"""Return a JSON response.""" """Return a JSON response."""
try: try:
msg = json.dumps( msg = json.dumps(

View file

@ -80,7 +80,9 @@ class ScriptIntentHandler(intent.IntentHandler):
if action is not None: if action is not None:
if is_async_action: if is_async_action:
intent_obj.hass.async_create_task(action.async_run(slots)) intent_obj.hass.async_create_task(
action.async_run(slots, intent_obj.context)
)
else: else:
await action.async_run(slots) await action.async_run(slots)

View file

@ -232,7 +232,9 @@ class SetIntentHandler(intent.IntentHandler):
service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"] service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"]
speech_parts.append("{}% brightness".format(slots["brightness"]["value"])) speech_parts.append("{}% brightness".format(slots["brightness"]["value"]))
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, service_data) await hass.services.async_call(
DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context
)
response = intent_obj.create_response() response = intent_obj.create_response()

View file

@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_SUPPORTED_FEATURES from homeassistant.const import ATTR_SUPPORTED_FEATURES
from homeassistant.core import callback, State, T from homeassistant.core import callback, State, T, Context
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
@ -53,6 +53,7 @@ async def async_handle(
intent_type: str, intent_type: str,
slots: Optional[_SlotsType] = None, slots: Optional[_SlotsType] = None,
text_input: Optional[str] = None, text_input: Optional[str] = None,
context: Optional[Context] = None,
) -> "IntentResponse": ) -> "IntentResponse":
"""Handle an intent.""" """Handle an intent."""
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type) handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
@ -60,7 +61,10 @@ async def async_handle(
if handler is None: if handler is None:
raise UnknownIntent(f"Unknown intent {intent_type}") raise UnknownIntent(f"Unknown intent {intent_type}")
intent = Intent(hass, platform, intent_type, slots or {}, text_input) if context is None:
context = Context()
intent = Intent(hass, platform, intent_type, slots or {}, text_input, context)
try: try:
_LOGGER.info("Triggering intent handler %s", handler) _LOGGER.info("Triggering intent handler %s", handler)
@ -196,7 +200,10 @@ class ServiceIntentHandler(IntentHandler):
state = async_match_state(hass, slots["name"]["value"]) state = async_match_state(hass, slots["name"]["value"])
await hass.services.async_call( await hass.services.async_call(
self.domain, self.service, {ATTR_ENTITY_ID: state.entity_id} self.domain,
self.service,
{ATTR_ENTITY_ID: state.entity_id},
context=intent_obj.context,
) )
response = intent_obj.create_response() response = intent_obj.create_response()
@ -207,7 +214,7 @@ class ServiceIntentHandler(IntentHandler):
class Intent: class Intent:
"""Hold the intent.""" """Hold the intent."""
__slots__ = ["hass", "platform", "intent_type", "slots", "text_input"] __slots__ = ["hass", "platform", "intent_type", "slots", "text_input", "context"]
def __init__( def __init__(
self, self,
@ -216,6 +223,7 @@ class Intent:
intent_type: str, intent_type: str,
slots: _SlotsType, slots: _SlotsType,
text_input: Optional[str], text_input: Optional[str],
context: Context,
) -> None: ) -> None:
"""Initialize an intent.""" """Initialize an intent."""
self.hass = hass self.hass = hass
@ -223,6 +231,7 @@ class Intent:
self.intent_type = intent_type self.intent_type = intent_type
self.slots = slots self.slots = slots
self.text_input = text_input self.text_input = text_input
self.context = context
@callback @callback
def create_response(self) -> "IntentResponse": def create_response(self) -> "IntentResponse":

View file

@ -2,7 +2,7 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import pytest import pytest
from homeassistant.core import DOMAIN as HASS_DOMAIN from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER from homeassistant.components.cover import SERVICE_OPEN_COVER
@ -25,10 +25,13 @@ async def test_calling_intent(hass):
) )
assert result assert result
context = Context()
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",
"process", "process",
{conversation.ATTR_TEXT: "I would like the Grolsch beer"}, {conversation.ATTR_TEXT: "I would like the Grolsch beer"},
context=context,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -38,6 +41,7 @@ async def test_calling_intent(hass):
assert intent.intent_type == "OrderBeer" assert intent.intent_type == "OrderBeer"
assert intent.slots == {"type": {"value": "Grolsch"}} assert intent.slots == {"type": {"value": "Grolsch"}}
assert intent.text_input == "I would like the Grolsch beer" assert intent.text_input == "I would like the Grolsch beer"
assert intent.context is context
async def test_register_before_setup(hass): async def test_register_before_setup(hass):
@ -80,7 +84,7 @@ async def test_register_before_setup(hass):
assert intent.text_input == "I would like the Grolsch beer" assert intent.text_input == "I would like the Grolsch beer"
async def test_http_processing_intent(hass, hass_client): async def test_http_processing_intent(hass, hass_client, hass_admin_user):
"""Test processing intent via HTTP API.""" """Test processing intent via HTTP API."""
class TestIntentHandler(intent.IntentHandler): class TestIntentHandler(intent.IntentHandler):
@ -90,6 +94,7 @@ async def test_http_processing_intent(hass, hass_client):
async def async_handle(self, intent): async def async_handle(self, intent):
"""Handle the intent.""" """Handle the intent."""
assert intent.context.user_id == hass_admin_user.id
response = intent.create_response() response = intent.create_response()
response.async_set_speech( response.async_set_speech(
"I've ordered a {}!".format(intent.slots["type"]["value"]) "I've ordered a {}!".format(intent.slots["type"]["value"])
@ -124,7 +129,7 @@ async def test_http_processing_intent(hass, hass_client):
} }
async def test_http_handle_intent(hass, hass_client): async def test_http_handle_intent(hass, hass_client, hass_admin_user):
"""Test handle intent via HTTP API.""" """Test handle intent via HTTP API."""
class TestIntentHandler(intent.IntentHandler): class TestIntentHandler(intent.IntentHandler):
@ -134,6 +139,7 @@ async def test_http_handle_intent(hass, hass_client):
async def async_handle(self, intent): async def async_handle(self, intent):
"""Handle the intent.""" """Handle the intent."""
assert intent.context.user_id == hass_admin_user.id
response = intent.create_response() response = intent.create_response()
response.async_set_speech( response.async_set_speech(
"I've ordered a {}!".format(intent.slots["type"]["value"]) "I've ordered a {}!".format(intent.slots["type"]["value"])
@ -308,7 +314,7 @@ async def test_http_api_wrong_data(hass, hass_client):
assert resp.status == 400 assert resp.status == 400
async def test_custom_agent(hass, hass_client): async def test_custom_agent(hass, hass_client, hass_admin_user):
"""Test a custom conversation agent.""" """Test a custom conversation agent."""
calls = [] calls = []
@ -316,9 +322,9 @@ async def test_custom_agent(hass, hass_client):
class MyAgent(conversation.AbstractConversationAgent): class MyAgent(conversation.AbstractConversationAgent):
"""Test Agent.""" """Test Agent."""
async def async_process(self, text, conversation_id): async def async_process(self, text, context, conversation_id):
"""Process some text.""" """Process some text."""
calls.append((text, conversation_id)) calls.append((text, context, conversation_id))
response = intent.IntentResponse() response = intent.IntentResponse()
response.async_set_speech("Test response") response.async_set_speech("Test response")
return response return response
@ -341,4 +347,5 @@ async def test_custom_agent(hass, hass_client):
assert len(calls) == 1 assert len(calls) == 1
assert calls[0][0] == "Test Text" assert calls[0][0] == "Test Text"
assert calls[0][1] == "test-conv-id" assert calls[0][1].user_id == hass_admin_user.id
assert calls[0][2] == "test-conv-id"