parent
6c6a5c50a5
commit
c76f768a82
9 changed files with 82 additions and 57 deletions
|
@ -10,7 +10,7 @@ from aiohttp import ClientSession, ClientError
|
|||
from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, CoreState
|
||||
from homeassistant.core import HomeAssistant, CoreState, Context
|
||||
from homeassistant.const import CONF_TYPE, CONF_HOST, EVENT_HOMEASSISTANT_START
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
|
@ -277,7 +277,7 @@ class AlmondAgent(conversation.AbstractConversationAgent):
|
|||
return True
|
||||
|
||||
async def async_process(
|
||||
self, text: str, conversation_id: Optional[str] = None
|
||||
self, text: str, context: Context, conversation_id: Optional[str] = None
|
||||
) -> intent.IntentResponse:
|
||||
"""Process a sentence."""
|
||||
response = await self.api.async_converse_text(text, conversation_id)
|
||||
|
|
|
@ -50,15 +50,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent):
|
|||
hass.data[DATA_AGENT] = agent
|
||||
|
||||
|
||||
async def get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
||||
"""Get agent."""
|
||||
agent = hass.data.get(DATA_AGENT)
|
||||
if agent is None:
|
||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||
return agent
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Register the process service."""
|
||||
hass.data[DATA_CONFIG] = config
|
||||
|
@ -67,8 +58,9 @@ async def async_setup(hass, config):
|
|||
"""Parse text into commands."""
|
||||
text = service.data[ATTR_TEXT]
|
||||
_LOGGER.debug("Processing: <%s>", text)
|
||||
agent = await _get_agent(hass)
|
||||
try:
|
||||
await process(hass, text, service.context.id)
|
||||
await agent.async_process(text, service.context)
|
||||
except intent.IntentHandleError as err:
|
||||
_LOGGER.error("Error processing %s: %s", text, err)
|
||||
|
||||
|
@ -84,27 +76,6 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
async def process(hass: core.HomeAssistant, text: str, conversation_id: str):
|
||||
"""Process text and get intent."""
|
||||
agent = await get_agent(hass)
|
||||
return await agent.async_process(text, conversation_id)
|
||||
|
||||
|
||||
async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str):
|
||||
"""Process text and get intent."""
|
||||
try:
|
||||
intent_result = await process(hass, text, conversation_id)
|
||||
except intent.IntentHandleError as err:
|
||||
intent_result = intent.IntentResponse()
|
||||
intent_result.async_set_speech(str(err))
|
||||
|
||||
if intent_result is None:
|
||||
intent_result = intent.IntentResponse()
|
||||
intent_result.async_set_speech("Sorry, I didn't understand that")
|
||||
|
||||
return intent_result
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{"type": "conversation/process", "text": str, vol.Optional("conversation_id"): str}
|
||||
|
@ -112,7 +83,10 @@ async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str):
|
|||
async def websocket_process(hass, connection, msg):
|
||||
"""Process text."""
|
||||
connection.send_result(
|
||||
msg["id"], await get_intent(hass, msg["text"], msg.get("conversation_id"))
|
||||
msg["id"],
|
||||
await _async_converse(
|
||||
hass, msg["text"], msg.get("conversation_id"), connection.context(msg)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -120,7 +94,7 @@ async def websocket_process(hass, connection, msg):
|
|||
@websocket_api.websocket_command({"type": "conversation/agent/info"})
|
||||
async def websocket_get_agent_info(hass, connection, msg):
|
||||
"""Do we need onboarding."""
|
||||
agent = await get_agent(hass)
|
||||
agent = await _get_agent(hass)
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
|
@ -135,7 +109,7 @@ async def websocket_get_agent_info(hass, connection, msg):
|
|||
@websocket_api.websocket_command({"type": "conversation/onboarding/set", "shown": bool})
|
||||
async def websocket_set_onboarding(hass, connection, msg):
|
||||
"""Set onboarding status."""
|
||||
agent = await get_agent(hass)
|
||||
agent = await _get_agent(hass)
|
||||
|
||||
success = await agent.async_set_onboarding(msg["shown"])
|
||||
|
||||
|
@ -157,8 +131,9 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||
async def post(self, request, data):
|
||||
"""Send a request for processing."""
|
||||
hass = request.app["hass"]
|
||||
intent_result = await get_intent(
|
||||
hass, data["text"], data.get("conversation_id")
|
||||
|
||||
intent_result = await _async_converse(
|
||||
hass, data["text"], data.get("conversation_id"), self.context(request)
|
||||
)
|
||||
|
||||
return self.json(intent_result)
|
||||
|
@ -188,7 +163,7 @@ class ConversationHandleView(http.HomeAssistantView):
|
|||
key: {"value": value} for key, value in data.get("data", {}).items()
|
||||
}
|
||||
intent_result = await intent.async_handle(
|
||||
hass, DOMAIN, intent_name, slots, ""
|
||||
hass, DOMAIN, intent_name, slots, "", self.context(request)
|
||||
)
|
||||
except intent.IntentHandleError as err:
|
||||
intent_result = intent.IntentResponse()
|
||||
|
@ -199,3 +174,30 @@ class ConversationHandleView(http.HomeAssistantView):
|
|||
intent_result.async_set_speech("Sorry, I couldn't handle that")
|
||||
|
||||
return self.json(intent_result)
|
||||
|
||||
|
||||
async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
||||
"""Get the active conversation agent."""
|
||||
agent = hass.data.get(DATA_AGENT)
|
||||
if agent is None:
|
||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||
return agent
|
||||
|
||||
|
||||
async def _async_converse(
|
||||
hass: core.HomeAssistant, text: str, conversation_id: str, context: core.Context
|
||||
) -> intent.IntentResponse:
|
||||
"""Process text and get intent."""
|
||||
agent = await _get_agent(hass)
|
||||
try:
|
||||
intent_result = await agent.async_process(text, context, conversation_id)
|
||||
except intent.IntentHandleError as err:
|
||||
intent_result = intent.IntentResponse()
|
||||
intent_result.async_set_speech(str(err))
|
||||
|
||||
if intent_result is None:
|
||||
intent_result = intent.IntentResponse()
|
||||
intent_result.async_set_speech("Sorry, I didn't understand that")
|
||||
|
||||
return intent_result
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import intent
|
||||
|
||||
|
||||
|
@ -23,6 +24,6 @@ class AbstractConversationAgent(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def async_process(
|
||||
self, text: str, conversation_id: Optional[str] = None
|
||||
self, text: str, context: Context, conversation_id: Optional[str] = None
|
||||
) -> intent.IntentResponse:
|
||||
"""Process a sentence."""
|
||||
|
|
|
@ -109,7 +109,7 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
async_register(self.hass, intent_type, sentences)
|
||||
|
||||
async def async_process(
|
||||
self, text: str, conversation_id: Optional[str] = None
|
||||
self, text: str, context: core.Context, conversation_id: Optional[str] = None
|
||||
) -> intent.IntentResponse:
|
||||
"""Process a sentence."""
|
||||
intents = self.hass.data[DOMAIN]
|
||||
|
@ -127,4 +127,5 @@ class DefaultAgent(AbstractConversationAgent):
|
|||
intent_type,
|
||||
{key: {"value": value} for key, value in match.groupdict().items()},
|
||||
text,
|
||||
context,
|
||||
)
|
||||
|
|
|
@ -34,8 +34,8 @@ class HomeAssistantView:
|
|||
requires_auth = True
|
||||
cors_allowed = False
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def context(self, request):
|
||||
@staticmethod
|
||||
def context(request):
|
||||
"""Generate a context from a request."""
|
||||
user = request.get("hass_user")
|
||||
if user is None:
|
||||
|
@ -43,7 +43,8 @@ class HomeAssistantView:
|
|||
|
||||
return Context(user_id=user.id)
|
||||
|
||||
def json(self, result, status_code=200, headers=None):
|
||||
@staticmethod
|
||||
def json(result, status_code=200, headers=None):
|
||||
"""Return a JSON response."""
|
||||
try:
|
||||
msg = json.dumps(
|
||||
|
|
|
@ -80,7 +80,9 @@ class ScriptIntentHandler(intent.IntentHandler):
|
|||
|
||||
if action is not None:
|
||||
if is_async_action:
|
||||
intent_obj.hass.async_create_task(action.async_run(slots))
|
||||
intent_obj.hass.async_create_task(
|
||||
action.async_run(slots, intent_obj.context)
|
||||
)
|
||||
else:
|
||||
await action.async_run(slots)
|
||||
|
||||
|
|
|
@ -232,7 +232,9 @@ class SetIntentHandler(intent.IntentHandler):
|
|||
service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"]
|
||||
speech_parts.append("{}% brightness".format(slots["brightness"]["value"]))
|
||||
|
||||
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, service_data)
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, Optional
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES
|
||||
from homeassistant.core import callback, State, T
|
||||
from homeassistant.core import callback, State, T, Context
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
@ -53,6 +53,7 @@ async def async_handle(
|
|||
intent_type: str,
|
||||
slots: Optional[_SlotsType] = None,
|
||||
text_input: Optional[str] = None,
|
||||
context: Optional[Context] = None,
|
||||
) -> "IntentResponse":
|
||||
"""Handle an intent."""
|
||||
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
|
||||
|
@ -60,7 +61,10 @@ async def async_handle(
|
|||
if handler is None:
|
||||
raise UnknownIntent(f"Unknown intent {intent_type}")
|
||||
|
||||
intent = Intent(hass, platform, intent_type, slots or {}, text_input)
|
||||
if context is None:
|
||||
context = Context()
|
||||
|
||||
intent = Intent(hass, platform, intent_type, slots or {}, text_input, context)
|
||||
|
||||
try:
|
||||
_LOGGER.info("Triggering intent handler %s", handler)
|
||||
|
@ -196,7 +200,10 @@ class ServiceIntentHandler(IntentHandler):
|
|||
state = async_match_state(hass, slots["name"]["value"])
|
||||
|
||||
await hass.services.async_call(
|
||||
self.domain, self.service, {ATTR_ENTITY_ID: state.entity_id}
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
|
@ -207,7 +214,7 @@ class ServiceIntentHandler(IntentHandler):
|
|||
class Intent:
|
||||
"""Hold the intent."""
|
||||
|
||||
__slots__ = ["hass", "platform", "intent_type", "slots", "text_input"]
|
||||
__slots__ = ["hass", "platform", "intent_type", "slots", "text_input", "context"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -216,6 +223,7 @@ class Intent:
|
|||
intent_type: str,
|
||||
slots: _SlotsType,
|
||||
text_input: Optional[str],
|
||||
context: Context,
|
||||
) -> None:
|
||||
"""Initialize an intent."""
|
||||
self.hass = hass
|
||||
|
@ -223,6 +231,7 @@ class Intent:
|
|||
self.intent_type = intent_type
|
||||
self.slots = slots
|
||||
self.text_input = text_input
|
||||
self.context = context
|
||||
|
||||
@callback
|
||||
def create_response(self) -> "IntentResponse":
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# pylint: disable=protected-access
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.cover import SERVICE_OPEN_COVER
|
||||
|
@ -25,10 +25,13 @@ async def test_calling_intent(hass):
|
|||
)
|
||||
assert result
|
||||
|
||||
context = Context()
|
||||
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{conversation.ATTR_TEXT: "I would like the Grolsch beer"},
|
||||
context=context,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -38,6 +41,7 @@ async def test_calling_intent(hass):
|
|||
assert intent.intent_type == "OrderBeer"
|
||||
assert intent.slots == {"type": {"value": "Grolsch"}}
|
||||
assert intent.text_input == "I would like the Grolsch beer"
|
||||
assert intent.context is context
|
||||
|
||||
|
||||
async def test_register_before_setup(hass):
|
||||
|
@ -80,7 +84,7 @@ async def test_register_before_setup(hass):
|
|||
assert intent.text_input == "I would like the Grolsch beer"
|
||||
|
||||
|
||||
async def test_http_processing_intent(hass, hass_client):
|
||||
async def test_http_processing_intent(hass, hass_client, hass_admin_user):
|
||||
"""Test processing intent via HTTP API."""
|
||||
|
||||
class TestIntentHandler(intent.IntentHandler):
|
||||
|
@ -90,6 +94,7 @@ async def test_http_processing_intent(hass, hass_client):
|
|||
|
||||
async def async_handle(self, intent):
|
||||
"""Handle the intent."""
|
||||
assert intent.context.user_id == hass_admin_user.id
|
||||
response = intent.create_response()
|
||||
response.async_set_speech(
|
||||
"I've ordered a {}!".format(intent.slots["type"]["value"])
|
||||
|
@ -124,7 +129,7 @@ async def test_http_processing_intent(hass, hass_client):
|
|||
}
|
||||
|
||||
|
||||
async def test_http_handle_intent(hass, hass_client):
|
||||
async def test_http_handle_intent(hass, hass_client, hass_admin_user):
|
||||
"""Test handle intent via HTTP API."""
|
||||
|
||||
class TestIntentHandler(intent.IntentHandler):
|
||||
|
@ -134,6 +139,7 @@ async def test_http_handle_intent(hass, hass_client):
|
|||
|
||||
async def async_handle(self, intent):
|
||||
"""Handle the intent."""
|
||||
assert intent.context.user_id == hass_admin_user.id
|
||||
response = intent.create_response()
|
||||
response.async_set_speech(
|
||||
"I've ordered a {}!".format(intent.slots["type"]["value"])
|
||||
|
@ -308,7 +314,7 @@ async def test_http_api_wrong_data(hass, hass_client):
|
|||
assert resp.status == 400
|
||||
|
||||
|
||||
async def test_custom_agent(hass, hass_client):
|
||||
async def test_custom_agent(hass, hass_client, hass_admin_user):
|
||||
"""Test a custom conversation agent."""
|
||||
|
||||
calls = []
|
||||
|
@ -316,9 +322,9 @@ async def test_custom_agent(hass, hass_client):
|
|||
class MyAgent(conversation.AbstractConversationAgent):
|
||||
"""Test Agent."""
|
||||
|
||||
async def async_process(self, text, conversation_id):
|
||||
async def async_process(self, text, context, conversation_id):
|
||||
"""Process some text."""
|
||||
calls.append((text, conversation_id))
|
||||
calls.append((text, context, conversation_id))
|
||||
response = intent.IntentResponse()
|
||||
response.async_set_speech("Test response")
|
||||
return response
|
||||
|
@ -341,4 +347,5 @@ async def test_custom_agent(hass, hass_client):
|
|||
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0] == "Test Text"
|
||||
assert calls[0][1] == "test-conv-id"
|
||||
assert calls[0][1].user_id == hass_admin_user.id
|
||||
assert calls[0][2] == "test-conv-id"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue