Add support for conversation ID (#28620)

This commit is contained in:
Paulus Schoutsen 2019-11-07 12:21:12 -08:00 committed by GitHub
parent 9b5fa2e67c
commit fadb6a3979
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 11 deletions

View file

@ -3,6 +3,7 @@ import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
import time import time
from typing import Optional
import async_timeout import async_timeout
from aiohttp import ClientSession, ClientError from aiohttp import ClientSession, ClientError
@ -205,9 +206,11 @@ class AlmondAgent(conversation.AbstractConversationAgent):
"""Initialize the agent.""" """Initialize the agent."""
self.api = api self.api = api
async def async_process(self, text: str) -> intent.IntentResponse: async def async_process(
self, text: str, conversation_id: Optional[str] = None
) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""
response = await self.api.async_converse_text(text) response = await self.api.async_converse_text(text, conversation_id)
buffer = "" buffer = ""
for message in response["messages"]: for message in response["messages"]:

View file

@ -53,7 +53,7 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Register the process service.""" """Register the process service."""
async def process(hass, text): async def process(hass, text, conversation_id):
"""Process a line of text.""" """Process a line of text."""
agent = hass.data.get(DATA_AGENT) agent = hass.data.get(DATA_AGENT)
@ -61,14 +61,14 @@ async def async_setup(hass, config):
agent = hass.data[DATA_AGENT] = DefaultAgent(hass) agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize(config) await agent.async_initialize(config)
return await agent.async_process(text) return await agent.async_process(text, conversation_id)
async def handle_service(service): async def handle_service(service):
"""Parse text into commands.""" """Parse text into commands."""
text = service.data[ATTR_TEXT] text = service.data[ATTR_TEXT]
_LOGGER.debug("Processing: <%s>", text) _LOGGER.debug("Processing: <%s>", text)
try: try:
await process(hass, text) await process(hass, text, service.context.id)
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err) _LOGGER.error("Error processing %s: %s", text, err)
@ -91,13 +91,17 @@ class ConversationProcessView(http.HomeAssistantView):
"""Initialize the conversation process view.""" """Initialize the conversation process view."""
self._process = process self._process = process
@RequestDataValidator(vol.Schema({vol.Required("text"): str})) @RequestDataValidator(
vol.Schema({vol.Required("text"): str, vol.Optional("conversation_id"): str})
)
async def post(self, request, data): async def post(self, request, data):
"""Send a request for processing.""" """Send a request for processing."""
hass = request.app["hass"] hass = request.app["hass"]
try: try:
intent_result = await self._process(hass, data["text"]) intent_result = await self._process(
hass, data["text"], data.get("conversation_id")
)
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
intent_result = intent.IntentResponse() intent_result = intent.IntentResponse()
intent_result.async_set_speech(str(err)) intent_result.async_set_speech(str(err))

View file

@ -1,5 +1,6 @@
"""Agent foundation for conversation integration.""" """Agent foundation for conversation integration."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from homeassistant.helpers import intent from homeassistant.helpers import intent
@ -8,5 +9,7 @@ class AbstractConversationAgent(ABC):
"""Abstract conversation agent.""" """Abstract conversation agent."""
@abstractmethod @abstractmethod
async def async_process(self, text: str) -> intent.IntentResponse: async def async_process(
self, text: str, conversation_id: Optional[str] = None
) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""

View file

@ -1,6 +1,7 @@
"""Standard conversastion implementation for Home Assistant.""" """Standard conversastion implementation for Home Assistant."""
import logging import logging
import re import re
from typing import Optional
from homeassistant import core from homeassistant import core
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
@ -107,7 +108,9 @@ class DefaultAgent(AbstractConversationAgent):
for intent_type, sentences in UTTERANCES[component].items(): for intent_type, sentences in UTTERANCES[component].items():
async_register(self.hass, intent_type, sentences) async_register(self.hass, intent_type, sentences)
async def async_process(self, text) -> intent.IntentResponse: async def async_process(
self, text: str, conversation_id: Optional[str] = None
) -> intent.IntentResponse:
"""Process a sentence.""" """Process a sentence."""
intents = self.hass.data[DOMAIN] intents = self.hass.data[DOMAIN]

View file

@ -266,11 +266,14 @@ async def test_http_api_wrong_data(hass, hass_client):
async def test_custom_agent(hass, hass_client): async def test_custom_agent(hass, hass_client):
"""Test a custom conversation agent.""" """Test a custom conversation agent."""
calls = []
class MyAgent(conversation.AbstractConversationAgent): class MyAgent(conversation.AbstractConversationAgent):
"""Test Agent.""" """Test Agent."""
async def async_process(self, text): async def async_process(self, text, conversation_id):
"""Process some text.""" """Process some text."""
calls.append((text, conversation_id))
response = intent.IntentResponse() response = intent.IntentResponse()
response.async_set_speech("Test response") response.async_set_speech("Test response")
return response return response
@ -281,9 +284,16 @@ async def test_custom_agent(hass, hass_client):
client = await hass_client() client = await hass_client()
resp = await client.post("/api/conversation/process", json={"text": "Test Text"}) resp = await client.post(
"/api/conversation/process",
json={"text": "Test Text", "conversation_id": "test-conv-id"},
)
assert resp.status == 200 assert resp.status == 200
assert await resp.json() == { assert await resp.json() == {
"card": {}, "card": {},
"speech": {"plain": {"extra_data": None, "speech": "Test response"}}, "speech": {"plain": {"extra_data": None, "speech": "Test response"}},
} }
assert len(calls) == 1
assert calls[0][0] == "Test Text"
assert calls[0][1] == "test-conv-id"