diff --git a/homeassistant/components/almond/__init__.py b/homeassistant/components/almond/__init__.py index 7c1f65f3ac3..66d4b2fc9af 100644 --- a/homeassistant/components/almond/__init__.py +++ b/homeassistant/components/almond/__init__.py @@ -10,7 +10,7 @@ from aiohttp import ClientSession, ClientError from pyalmond import AlmondLocalAuth, AbstractAlmondWebAuth, WebAlmondAPI import voluptuous as vol -from homeassistant.core import HomeAssistant, CoreState +from homeassistant.core import HomeAssistant, CoreState, Context from homeassistant.const import CONF_TYPE, CONF_HOST, EVENT_HOMEASSISTANT_START from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.auth.const import GROUP_ID_ADMIN @@ -277,7 +277,7 @@ class AlmondAgent(conversation.AbstractConversationAgent): return True async def async_process( - self, text: str, conversation_id: Optional[str] = None + self, text: str, context: Context, conversation_id: Optional[str] = None ) -> intent.IntentResponse: """Process a sentence.""" response = await self.api.async_converse_text(text, conversation_id) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index b2648d3d1f6..ba8b211e65a 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -50,15 +50,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent): hass.data[DATA_AGENT] = agent -async def get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent: - """Get agent.""" - agent = hass.data.get(DATA_AGENT) - if agent is None: - agent = hass.data[DATA_AGENT] = DefaultAgent(hass) - await agent.async_initialize(hass.data.get(DATA_CONFIG)) - return agent - - async def async_setup(hass, config): """Register the process service.""" hass.data[DATA_CONFIG] = config @@ -67,8 +58,9 @@ async def async_setup(hass, config): """Parse text into commands.""" text = service.data[ATTR_TEXT] _LOGGER.debug("Processing: <%s>", text) + agent = await _get_agent(hass) try: - await process(hass, text, service.context.id) + await agent.async_process(text, service.context) except intent.IntentHandleError as err: _LOGGER.error("Error processing %s: %s", text, err) @@ -84,27 +76,6 @@ async def async_setup(hass, config): return True -async def process(hass: core.HomeAssistant, text: str, conversation_id: str): - """Process text and get intent.""" - agent = await get_agent(hass) - return await agent.async_process(text, conversation_id) - - -async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str): - """Process text and get intent.""" - try: - intent_result = await process(hass, text, conversation_id) - except intent.IntentHandleError as err: - intent_result = intent.IntentResponse() - intent_result.async_set_speech(str(err)) - - if intent_result is None: - intent_result = intent.IntentResponse() - intent_result.async_set_speech("Sorry, I didn't understand that") - - return intent_result - - @websocket_api.async_response @websocket_api.websocket_command( {"type": "conversation/process", "text": str, vol.Optional("conversation_id"): str} @@ -112,7 +83,10 @@ async def get_intent(hass: core.HomeAssistant, text: str, conversation_id: str): async def websocket_process(hass, connection, msg): """Process text.""" connection.send_result( - msg["id"], await get_intent(hass, msg["text"], msg.get("conversation_id")) + msg["id"], + await _async_converse( + hass, msg["text"], msg.get("conversation_id"), connection.context(msg) + ), ) @@ -120,7 +94,7 @@ async def websocket_process(hass, connection, msg): @websocket_api.websocket_command({"type": "conversation/agent/info"}) async def websocket_get_agent_info(hass, connection, msg): """Do we need onboarding.""" - agent = await get_agent(hass) + agent = await _get_agent(hass) connection.send_result( msg["id"], @@ -135,7 +109,7 @@ async def websocket_get_agent_info(hass, connection, msg): @websocket_api.websocket_command({"type": "conversation/onboarding/set", "shown": bool}) async def websocket_set_onboarding(hass, connection, msg): """Set onboarding status.""" - agent = await get_agent(hass) + agent = await _get_agent(hass) success = await agent.async_set_onboarding(msg["shown"]) @@ -157,8 +131,9 @@ class ConversationProcessView(http.HomeAssistantView): async def post(self, request, data): """Send a request for processing.""" hass = request.app["hass"] - intent_result = await get_intent( - hass, data["text"], data.get("conversation_id") + + intent_result = await _async_converse( + hass, data["text"], data.get("conversation_id"), self.context(request) ) return self.json(intent_result) @@ -188,7 +163,7 @@ class ConversationHandleView(http.HomeAssistantView): key: {"value": value} for key, value in data.get("data", {}).items() } intent_result = await intent.async_handle( - hass, DOMAIN, intent_name, slots, "" + hass, DOMAIN, intent_name, slots, "", self.context(request) ) except intent.IntentHandleError as err: intent_result = intent.IntentResponse() @@ -199,3 +174,30 @@ class ConversationHandleView(http.HomeAssistantView): intent_result.async_set_speech("Sorry, I couldn't handle that") return self.json(intent_result) + + +async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent: + """Get the active conversation agent.""" + agent = hass.data.get(DATA_AGENT) + if agent is None: + agent = hass.data[DATA_AGENT] = DefaultAgent(hass) + await agent.async_initialize(hass.data.get(DATA_CONFIG)) + return agent + + +async def _async_converse( + hass: core.HomeAssistant, text: str, conversation_id: str, context: core.Context +) -> intent.IntentResponse: + """Process text and get intent.""" + agent = await _get_agent(hass) + try: + intent_result = await agent.async_process(text, context, conversation_id) + except intent.IntentHandleError as err: + intent_result = intent.IntentResponse() + intent_result.async_set_speech(str(err)) + + if intent_result is None: + intent_result = intent.IntentResponse() + intent_result.async_set_speech("Sorry, I didn't understand that") + + return intent_result diff --git a/homeassistant/components/conversation/agent.py b/homeassistant/components/conversation/agent.py index 0c47d615645..c9c2ab46cf9 100644 --- a/homeassistant/components/conversation/agent.py +++ b/homeassistant/components/conversation/agent.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional +from homeassistant.core import Context from homeassistant.helpers import intent @@ -23,6 +24,6 @@ class AbstractConversationAgent(ABC): @abstractmethod async def async_process( - self, text: str, conversation_id: Optional[str] = None + self, text: str, context: Context, conversation_id: Optional[str] = None ) -> intent.IntentResponse: """Process a sentence.""" diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index c202cdf1e65..e562eed7e66 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -109,7 +109,7 @@ class DefaultAgent(AbstractConversationAgent): async_register(self.hass, intent_type, sentences) async def async_process( - self, text: str, conversation_id: Optional[str] = None + self, text: str, context: core.Context, conversation_id: Optional[str] = None ) -> intent.IntentResponse: """Process a sentence.""" intents = self.hass.data[DOMAIN] @@ -127,4 +127,5 @@ class DefaultAgent(AbstractConversationAgent): intent_type, {key: {"value": value} for key, value in match.groupdict().items()}, text, + context, ) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index 804c90d4f96..31f96833667 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -34,8 +34,8 @@ class HomeAssistantView: requires_auth = True cors_allowed = False - # pylint: disable=no-self-use - def context(self, request): + @staticmethod + def context(request): """Generate a context from a request.""" user = request.get("hass_user") if user is None: @@ -43,7 +43,8 @@ class HomeAssistantView: return Context(user_id=user.id) - def json(self, result, status_code=200, headers=None): + @staticmethod + def json(result, status_code=200, headers=None): """Return a JSON response.""" try: msg = json.dumps( diff --git a/homeassistant/components/intent_script/__init__.py b/homeassistant/components/intent_script/__init__.py index 75a0c0e8f97..ce4b8b27a51 100644 --- a/homeassistant/components/intent_script/__init__.py +++ b/homeassistant/components/intent_script/__init__.py @@ -80,7 +80,9 @@ class ScriptIntentHandler(intent.IntentHandler): if action is not None: if is_async_action: - intent_obj.hass.async_create_task(action.async_run(slots)) + intent_obj.hass.async_create_task( + action.async_run(slots, intent_obj.context) + ) else: await action.async_run(slots) diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 2ca5e496b10..b33cb29421e 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -232,7 +232,9 @@ class SetIntentHandler(intent.IntentHandler): service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"] speech_parts.append("{}% brightness".format(slots["brightness"]["value"])) - await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, service_data) + await hass.services.async_call( + DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context + ) response = intent_obj.create_response() diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index dc48d825348..12b346603f0 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, Optional import voluptuous as vol from homeassistant.const import ATTR_SUPPORTED_FEATURES -from homeassistant.core import callback, State, T +from homeassistant.core import callback, State, T, Context from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import HomeAssistantType @@ -53,6 +53,7 @@ async def async_handle( intent_type: str, slots: Optional[_SlotsType] = None, text_input: Optional[str] = None, + context: Optional[Context] = None, ) -> "IntentResponse": """Handle an intent.""" handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type) @@ -60,7 +61,10 @@ async def async_handle( if handler is None: raise UnknownIntent(f"Unknown intent {intent_type}") - intent = Intent(hass, platform, intent_type, slots or {}, text_input) + if context is None: + context = Context() + + intent = Intent(hass, platform, intent_type, slots or {}, text_input, context) try: _LOGGER.info("Triggering intent handler %s", handler) @@ -196,7 +200,10 @@ class ServiceIntentHandler(IntentHandler): state = async_match_state(hass, slots["name"]["value"]) await hass.services.async_call( - self.domain, self.service, {ATTR_ENTITY_ID: state.entity_id} + self.domain, + self.service, + {ATTR_ENTITY_ID: state.entity_id}, + context=intent_obj.context, ) response = intent_obj.create_response() @@ -207,7 +214,7 @@ class ServiceIntentHandler(IntentHandler): class Intent: """Hold the intent.""" - __slots__ = ["hass", "platform", "intent_type", "slots", "text_input"] + __slots__ = ["hass", "platform", "intent_type", "slots", "text_input", "context"] def __init__( self, @@ -216,6 +223,7 @@ class Intent: intent_type: str, slots: _SlotsType, text_input: Optional[str], + context: Context, ) -> None: """Initialize an intent.""" self.hass = hass @@ -223,6 +231,7 @@ class Intent: self.intent_type = intent_type self.slots = slots self.text_input = text_input + self.context = context @callback def create_response(self) -> "IntentResponse": diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index fc6508159ea..3982ed6f699 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -2,7 +2,7 @@ # pylint: disable=protected-access import pytest -from homeassistant.core import DOMAIN as HASS_DOMAIN +from homeassistant.core import DOMAIN as HASS_DOMAIN, Context from homeassistant.setup import async_setup_component from homeassistant.components import conversation from homeassistant.components.cover import SERVICE_OPEN_COVER @@ -25,10 +25,13 @@ async def test_calling_intent(hass): ) assert result + context = Context() + await hass.services.async_call( "conversation", "process", {conversation.ATTR_TEXT: "I would like the Grolsch beer"}, + context=context, ) await hass.async_block_till_done() @@ -38,6 +41,7 @@ async def test_calling_intent(hass): assert intent.intent_type == "OrderBeer" assert intent.slots == {"type": {"value": "Grolsch"}} assert intent.text_input == "I would like the Grolsch beer" + assert intent.context is context async def test_register_before_setup(hass): @@ -80,7 +84,7 @@ async def test_register_before_setup(hass): assert intent.text_input == "I would like the Grolsch beer" -async def test_http_processing_intent(hass, hass_client): +async def test_http_processing_intent(hass, hass_client, hass_admin_user): """Test processing intent via HTTP API.""" class TestIntentHandler(intent.IntentHandler): @@ -90,6 +94,7 @@ async def test_http_processing_intent(hass, hass_client): async def async_handle(self, intent): """Handle the intent.""" + assert intent.context.user_id == hass_admin_user.id response = intent.create_response() response.async_set_speech( "I've ordered a {}!".format(intent.slots["type"]["value"]) @@ -124,7 +129,7 @@ async def test_http_processing_intent(hass, hass_client): } -async def test_http_handle_intent(hass, hass_client): +async def test_http_handle_intent(hass, hass_client, hass_admin_user): """Test handle intent via HTTP API.""" class TestIntentHandler(intent.IntentHandler): @@ -134,6 +139,7 @@ async def test_http_handle_intent(hass, hass_client): async def async_handle(self, intent): """Handle the intent.""" + assert intent.context.user_id == hass_admin_user.id response = intent.create_response() response.async_set_speech( "I've ordered a {}!".format(intent.slots["type"]["value"]) @@ -308,7 +314,7 @@ async def test_http_api_wrong_data(hass, hass_client): assert resp.status == 400 -async def test_custom_agent(hass, hass_client): +async def test_custom_agent(hass, hass_client, hass_admin_user): """Test a custom conversation agent.""" calls = [] @@ -316,9 +322,9 @@ async def test_custom_agent(hass, hass_client): class MyAgent(conversation.AbstractConversationAgent): """Test Agent.""" - async def async_process(self, text, conversation_id): + async def async_process(self, text, context, conversation_id): """Process some text.""" - calls.append((text, conversation_id)) + calls.append((text, context, conversation_id)) response = intent.IntentResponse() response.async_set_speech("Test response") return response @@ -341,4 +347,5 @@ async def test_custom_agent(hass, hass_client): assert len(calls) == 1 assert calls[0][0] == "Test Text" - assert calls[0][1] == "test-conv-id" + assert calls[0][1].user_id == hass_admin_user.id + assert calls[0][2] == "test-conv-id"