Add support for conversation ID (#28620)
This commit is contained in:
parent
9b5fa2e67c
commit
fadb6a3979
5 changed files with 34 additions and 11 deletions
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from aiohttp import ClientSession, ClientError
|
from aiohttp import ClientSession, ClientError
|
||||||
|
@ -205,9 +206,11 @@ class AlmondAgent(conversation.AbstractConversationAgent):
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.api = api
|
self.api = api
|
||||||
|
|
||||||
async def async_process(self, text: str) -> intent.IntentResponse:
|
async def async_process(
|
||||||
|
self, text: str, conversation_id: Optional[str] = None
|
||||||
|
) -> intent.IntentResponse:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
response = await self.api.async_converse_text(text)
|
response = await self.api.async_converse_text(text, conversation_id)
|
||||||
|
|
||||||
buffer = ""
|
buffer = ""
|
||||||
for message in response["messages"]:
|
for message in response["messages"]:
|
||||||
|
|
|
@ -53,7 +53,7 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent):
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
|
|
||||||
async def process(hass, text):
|
async def process(hass, text, conversation_id):
|
||||||
"""Process a line of text."""
|
"""Process a line of text."""
|
||||||
agent = hass.data.get(DATA_AGENT)
|
agent = hass.data.get(DATA_AGENT)
|
||||||
|
|
||||||
|
@ -61,14 +61,14 @@ async def async_setup(hass, config):
|
||||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||||
await agent.async_initialize(config)
|
await agent.async_initialize(config)
|
||||||
|
|
||||||
return await agent.async_process(text)
|
return await agent.async_process(text, conversation_id)
|
||||||
|
|
||||||
async def handle_service(service):
|
async def handle_service(service):
|
||||||
"""Parse text into commands."""
|
"""Parse text into commands."""
|
||||||
text = service.data[ATTR_TEXT]
|
text = service.data[ATTR_TEXT]
|
||||||
_LOGGER.debug("Processing: <%s>", text)
|
_LOGGER.debug("Processing: <%s>", text)
|
||||||
try:
|
try:
|
||||||
await process(hass, text)
|
await process(hass, text, service.context.id)
|
||||||
except intent.IntentHandleError as err:
|
except intent.IntentHandleError as err:
|
||||||
_LOGGER.error("Error processing %s: %s", text, err)
|
_LOGGER.error("Error processing %s: %s", text, err)
|
||||||
|
|
||||||
|
@ -91,13 +91,17 @@ class ConversationProcessView(http.HomeAssistantView):
|
||||||
"""Initialize the conversation process view."""
|
"""Initialize the conversation process view."""
|
||||||
self._process = process
|
self._process = process
|
||||||
|
|
||||||
@RequestDataValidator(vol.Schema({vol.Required("text"): str}))
|
@RequestDataValidator(
|
||||||
|
vol.Schema({vol.Required("text"): str, vol.Optional("conversation_id"): str})
|
||||||
|
)
|
||||||
async def post(self, request, data):
|
async def post(self, request, data):
|
||||||
"""Send a request for processing."""
|
"""Send a request for processing."""
|
||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
intent_result = await self._process(hass, data["text"])
|
intent_result = await self._process(
|
||||||
|
hass, data["text"], data.get("conversation_id")
|
||||||
|
)
|
||||||
except intent.IntentHandleError as err:
|
except intent.IntentHandleError as err:
|
||||||
intent_result = intent.IntentResponse()
|
intent_result = intent.IntentResponse()
|
||||||
intent_result.async_set_speech(str(err))
|
intent_result.async_set_speech(str(err))
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Agent foundation for conversation integration."""
|
"""Agent foundation for conversation integration."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
|
|
||||||
|
@ -8,5 +9,7 @@ class AbstractConversationAgent(ABC):
|
||||||
"""Abstract conversation agent."""
|
"""Abstract conversation agent."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_process(self, text: str) -> intent.IntentResponse:
|
async def async_process(
|
||||||
|
self, text: str, conversation_id: Optional[str] = None
|
||||||
|
) -> intent.IntentResponse:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Standard conversastion implementation for Home Assistant."""
|
"""Standard conversastion implementation for Home Assistant."""
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||||
|
@ -107,7 +108,9 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
for intent_type, sentences in UTTERANCES[component].items():
|
for intent_type, sentences in UTTERANCES[component].items():
|
||||||
async_register(self.hass, intent_type, sentences)
|
async_register(self.hass, intent_type, sentences)
|
||||||
|
|
||||||
async def async_process(self, text) -> intent.IntentResponse:
|
async def async_process(
|
||||||
|
self, text: str, conversation_id: Optional[str] = None
|
||||||
|
) -> intent.IntentResponse:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
intents = self.hass.data[DOMAIN]
|
intents = self.hass.data[DOMAIN]
|
||||||
|
|
||||||
|
|
|
@ -266,11 +266,14 @@ async def test_http_api_wrong_data(hass, hass_client):
|
||||||
async def test_custom_agent(hass, hass_client):
|
async def test_custom_agent(hass, hass_client):
|
||||||
"""Test a custom conversation agent."""
|
"""Test a custom conversation agent."""
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
class MyAgent(conversation.AbstractConversationAgent):
|
class MyAgent(conversation.AbstractConversationAgent):
|
||||||
"""Test Agent."""
|
"""Test Agent."""
|
||||||
|
|
||||||
async def async_process(self, text):
|
async def async_process(self, text, conversation_id):
|
||||||
"""Process some text."""
|
"""Process some text."""
|
||||||
|
calls.append((text, conversation_id))
|
||||||
response = intent.IntentResponse()
|
response = intent.IntentResponse()
|
||||||
response.async_set_speech("Test response")
|
response.async_set_speech("Test response")
|
||||||
return response
|
return response
|
||||||
|
@ -281,9 +284,16 @@ async def test_custom_agent(hass, hass_client):
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
resp = await client.post("/api/conversation/process", json={"text": "Test Text"})
|
resp = await client.post(
|
||||||
|
"/api/conversation/process",
|
||||||
|
json={"text": "Test Text", "conversation_id": "test-conv-id"},
|
||||||
|
)
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert await resp.json() == {
|
assert await resp.json() == {
|
||||||
"card": {},
|
"card": {},
|
||||||
"speech": {"plain": {"extra_data": None, "speech": "Test response"}},
|
"speech": {"plain": {"extra_data": None, "speech": "Test response"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0][0] == "Test Text"
|
||||||
|
assert calls[0][1] == "test-conv-id"
|
||||||
|
|
Loading…
Add table
Reference in a new issue