Add attribution and onboarding commands to conversation and Almond (#28621)
* Add attribution and onboarding commands to conversation and Almond * False -> None * Comments * Update __init__.py * Comments + websocket for convert * Lint
This commit is contained in:
parent
4435b3a5c9
commit
28c6837f00
3 changed files with 128 additions and 33 deletions
|
@ -10,6 +10,7 @@ from aiohttp import ClientSession, ClientError
|
||||||
from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI
|
from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import core
|
||||||
from homeassistant.const import CONF_TYPE, CONF_HOST
|
from homeassistant.const import CONF_TYPE, CONF_HOST
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryNotReady
|
||||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||||
|
@ -95,9 +96,9 @@ async def async_setup(hass, config):
|
||||||
async def async_setup_entry(hass, entry):
|
async def async_setup_entry(hass, entry):
|
||||||
"""Set up Almond config entry."""
|
"""Set up Almond config entry."""
|
||||||
websession = aiohttp_client.async_get_clientsession(hass)
|
websession = aiohttp_client.async_get_clientsession(hass)
|
||||||
|
|
||||||
if entry.data["type"] == TYPE_LOCAL:
|
if entry.data["type"] == TYPE_LOCAL:
|
||||||
auth = AlmondLocalAuth(entry.data["host"], websession)
|
auth = AlmondLocalAuth(entry.data["host"], websession)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# OAuth2
|
# OAuth2
|
||||||
implementation = await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
implementation = await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||||
|
@ -109,7 +110,7 @@ async def async_setup_entry(hass, entry):
|
||||||
auth = AlmondOAuth(entry.data["host"], websession, oauth_session)
|
auth = AlmondOAuth(entry.data["host"], websession, oauth_session)
|
||||||
|
|
||||||
api = WebAlmondAPI(auth)
|
api = WebAlmondAPI(auth)
|
||||||
agent = AlmondAgent(api)
|
agent = AlmondAgent(hass, api, entry)
|
||||||
|
|
||||||
# Hass.io does its own configuration of Almond.
|
# Hass.io does its own configuration of Almond.
|
||||||
if entry.data.get("is_hassio") or entry.data["type"] != TYPE_LOCAL:
|
if entry.data.get("is_hassio") or entry.data["type"] != TYPE_LOCAL:
|
||||||
|
@ -202,9 +203,39 @@ class AlmondOAuth(AbstractAlmondWebAuth):
|
||||||
class AlmondAgent(conversation.AbstractConversationAgent):
|
class AlmondAgent(conversation.AbstractConversationAgent):
|
||||||
"""Almond conversation agent."""
|
"""Almond conversation agent."""
|
||||||
|
|
||||||
def __init__(self, api: WebAlmondAPI):
|
def __init__(self, hass: core.HomeAssistant, api: WebAlmondAPI, entry):
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
|
self.hass = hass
|
||||||
self.api = api
|
self.api = api
|
||||||
|
self.entry = entry
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attribution(self):
|
||||||
|
"""Return the attribution."""
|
||||||
|
return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"}
|
||||||
|
|
||||||
|
async def async_get_onboarding(self):
|
||||||
|
"""Get onboard url if not onboarded."""
|
||||||
|
if self.entry.data.get("onboarded"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
host = self.entry.data["host"]
|
||||||
|
if self.entry.data.get("is_hassio"):
|
||||||
|
host = "/core_almond"
|
||||||
|
elif self.entry.data["type"] != TYPE_LOCAL:
|
||||||
|
host = f"{host}/me"
|
||||||
|
return {
|
||||||
|
"text": "Would you like to opt-in to share your anonymized commands with Stanford to improve Almond's responses?",
|
||||||
|
"url": f"{host}/conversation",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def async_set_onboarding(self, shown):
|
||||||
|
"""Set onboarding status."""
|
||||||
|
self.hass.config_entries.async_update_entry(
|
||||||
|
self.entry, data={**self.entry.data, "onboarded": shown}
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, text: str, conversation_id: Optional[str] = None
|
self, text: str, conversation_id: Optional[str] = None
|
||||||
|
|
|
@ -5,7 +5,7 @@ import re
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
from homeassistant.components import http
|
from homeassistant.components import http, websocket_api
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.helpers import config_validation as cv, intent
|
from homeassistant.helpers import config_validation as cv, intent
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
@ -21,6 +21,7 @@ DOMAIN = "conversation"
|
||||||
|
|
||||||
REGEX_TYPE = type(re.compile(""))
|
REGEX_TYPE = type(re.compile(""))
|
||||||
DATA_AGENT = "conversation_agent"
|
DATA_AGENT = "conversation_agent"
|
||||||
|
DATA_CONFIG = "conversation_config"
|
||||||
|
|
||||||
SERVICE_PROCESS = "process"
|
SERVICE_PROCESS = "process"
|
||||||
|
|
||||||
|
@ -39,7 +40,6 @@ CONFIG_SCHEMA = vol.Schema(
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async_register = bind_hass(async_register) # pylint: disable=invalid-name
|
async_register = bind_hass(async_register) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,18 +50,19 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent):
|
||||||
hass.data[DATA_AGENT] = agent
|
hass.data[DATA_AGENT] = agent
|
||||||
|
|
||||||
|
|
||||||
|
async def get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
||||||
|
"""Get agent."""
|
||||||
|
agent = hass.data.get(DATA_AGENT)
|
||||||
|
if agent is None:
|
||||||
|
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||||
|
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
|
|
||||||
async def process(hass, text, conversation_id):
|
hass.data[DATA_CONFIG] = config
|
||||||
"""Process a line of text."""
|
|
||||||
agent = hass.data.get(DATA_AGENT)
|
|
||||||
|
|
||||||
if agent is None:
|
|
||||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
|
||||||
await agent.async_initialize(config)
|
|
||||||
|
|
||||||
return await agent.async_process(text, conversation_id)
|
|
||||||
|
|
||||||
async def handle_service(service):
|
async def handle_service(service):
|
||||||
"""Parse text into commands."""
|
"""Parse text into commands."""
|
||||||
|
@ -75,39 +76,89 @@ async def async_setup(hass, config):
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN, SERVICE_PROCESS, handle_service, schema=SERVICE_PROCESS_SCHEMA
|
DOMAIN, SERVICE_PROCESS, handle_service, schema=SERVICE_PROCESS_SCHEMA
|
||||||
)
|
)
|
||||||
|
hass.http.register_view(ConversationProcessView())
|
||||||
hass.http.register_view(ConversationProcessView(process))
|
hass.components.websocket_api.async_register_command(websocket_process)
|
||||||
|
hass.components.websocket_api.async_register_command(websocket_get_agent_info)
|
||||||
|
hass.components.websocket_api.async_register_command(websocket_set_onboarding)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def process(hass: core.HomeAssistant, text: str, conversation_id: str):
|
||||||
|
"""Process text and get intent."""
|
||||||
|
agent = await get_agent(hass)
|
||||||
|
return await agent.async_process(text, conversation_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str):
|
||||||
|
"""Process text and get intent."""
|
||||||
|
try:
|
||||||
|
intent_result = await process(hass, text, conversation_id)
|
||||||
|
except intent.IntentHandleError as err:
|
||||||
|
intent_result = intent.IntentResponse()
|
||||||
|
intent_result.async_set_speech(str(err))
|
||||||
|
|
||||||
|
if intent_result is None:
|
||||||
|
intent_result = intent.IntentResponse()
|
||||||
|
intent_result.async_set_speech("Sorry, I didn't understand that")
|
||||||
|
|
||||||
|
return intent_result
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.async_response
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{"type": "conversation/process", "text": str, vol.Optional("conversation_id"): str}
|
||||||
|
)
|
||||||
|
async def websocket_process(hass, connection, msg):
|
||||||
|
"""Process text."""
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"], await get_intent(hass, msg["text"], msg.get("conversation_id"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.async_response
|
||||||
|
@websocket_api.websocket_command({"type": "conversation/agent/info"})
|
||||||
|
async def websocket_get_agent_info(hass, connection, msg):
|
||||||
|
"""Do we need onboarding."""
|
||||||
|
agent = await get_agent(hass)
|
||||||
|
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"],
|
||||||
|
{
|
||||||
|
"onboarding": await agent.async_get_onboarding(),
|
||||||
|
"attribution": agent.attribution,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.async_response
|
||||||
|
@websocket_api.websocket_command({"type": "conversation/onboarding/set", "shown": bool})
|
||||||
|
async def websocket_set_onboarding(hass, connection, msg):
|
||||||
|
"""Set onboarding status."""
|
||||||
|
agent = await get_agent(hass)
|
||||||
|
|
||||||
|
success = await agent.async_set_onboarding(msg["shown"])
|
||||||
|
|
||||||
|
if success:
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
else:
|
||||||
|
connection.send_error(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessView(http.HomeAssistantView):
|
class ConversationProcessView(http.HomeAssistantView):
|
||||||
"""View to retrieve shopping list content."""
|
"""View to process text."""
|
||||||
|
|
||||||
url = "/api/conversation/process"
|
url = "/api/conversation/process"
|
||||||
name = "api:conversation:process"
|
name = "api:conversation:process"
|
||||||
|
|
||||||
def __init__(self, process):
|
|
||||||
"""Initialize the conversation process view."""
|
|
||||||
self._process = process
|
|
||||||
|
|
||||||
@RequestDataValidator(
|
@RequestDataValidator(
|
||||||
vol.Schema({vol.Required("text"): str, vol.Optional("conversation_id"): str})
|
vol.Schema({vol.Required("text"): str, vol.Optional("conversation_id"): str})
|
||||||
)
|
)
|
||||||
async def post(self, request, data):
|
async def post(self, request, data):
|
||||||
"""Send a request for processing."""
|
"""Send a request for processing."""
|
||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
|
intent_result = await get_intent(
|
||||||
try:
|
hass, data["text"], data.get("conversation_id")
|
||||||
intent_result = await self._process(
|
)
|
||||||
hass, data["text"], data.get("conversation_id")
|
|
||||||
)
|
|
||||||
except intent.IntentHandleError as err:
|
|
||||||
intent_result = intent.IntentResponse()
|
|
||||||
intent_result.async_set_speech(str(err))
|
|
||||||
|
|
||||||
if intent_result is None:
|
|
||||||
intent_result = intent.IntentResponse()
|
|
||||||
intent_result.async_set_speech("Sorry, I didn't understand that")
|
|
||||||
|
|
||||||
return self.json(intent_result)
|
return self.json(intent_result)
|
||||||
|
|
|
@ -8,6 +8,19 @@ from homeassistant.helpers import intent
|
||||||
class AbstractConversationAgent(ABC):
|
class AbstractConversationAgent(ABC):
|
||||||
"""Abstract conversation agent."""
|
"""Abstract conversation agent."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attribution(self):
|
||||||
|
"""Return the attribution."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_get_onboarding(self):
|
||||||
|
"""Get onboard data."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_set_onboarding(self, shown):
|
||||||
|
"""Set onboard data."""
|
||||||
|
return True
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, text: str, conversation_id: Optional[str] = None
|
self, text: str, conversation_id: Optional[str] = None
|
||||||
|
|
Loading…
Add table
Reference in a new issue