Add device_id to sentence trigger and external conversation APIs (#113094)

* Add device_id to sentence trigger and external conversation APIs

* Remove device_id from external API

* Update tests/components/conversation/snapshots/test_init.ambr

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-03-12 07:50:06 -05:00 committed by GitHub
parent 120525e94f
commit 556855f54e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 61 additions and 13 deletions

View file

@ -53,7 +53,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"] _ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[[str, RecognizeResult], Awaitable[str | None]] TRIGGER_CALLBACK_TYPE = Callable[
[str, RecognizeResult, str | None], Awaitable[str | None]
]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence" METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file" METADATA_CUSTOM_FILE = "hass_custom_file"
@ -224,7 +226,7 @@ class DefaultAgent(AbstractConversationAgent):
# Gather callback responses in parallel # Gather callback responses in parallel
trigger_callbacks = [ trigger_callbacks = [
self._trigger_sentences[trigger_id].callback( self._trigger_sentences[trigger_id].callback(
result.sentence, trigger_result result.sentence, trigger_result, user_input.device_id
) )
for trigger_id, trigger_result in result.matched_triggers.items() for trigger_id, trigger_result in result.matched_triggers.items()
] ]

View file

@ -62,7 +62,9 @@ async def async_attach_trigger(
job = HassJob(action) job = HassJob(action)
async def call_action(sentence: str, result: RecognizeResult) -> str | None: async def call_action(
sentence: str, result: RecognizeResult, device_id: str | None
) -> str | None:
"""Call action with right context.""" """Call action with right context."""
# Add slot values as extra trigger data # Add slot values as extra trigger data
@ -70,9 +72,11 @@ async def async_attach_trigger(
entity_name: { entity_name: {
"name": entity_name, "name": entity_name,
"text": entity.text.strip(), # remove whitespace "text": entity.text.strip(), # remove whitespace
"value": entity.value.strip() "value": (
if isinstance(entity.value, str) entity.value.strip()
else entity.value, if isinstance(entity.value, str)
else entity.value
),
} }
for entity_name, entity in result.entities.items() for entity_name, entity in result.entities.items()
} }
@ -85,6 +89,7 @@ async def async_attach_trigger(
"slots": { # direct access to values "slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items() entity_name: entity["value"] for entity_name, entity in details.items()
}, },
"device_id": device_id,
} }
# Wait for the automation to complete # Wait for the automation to complete

View file

@ -535,9 +535,12 @@ async def test_turn_on_intent(
async def test_service_fails(hass: HomeAssistant, init_components) -> None: async def test_service_fails(hass: HomeAssistant, init_components) -> None:
"""Test calling the turn on intent.""" """Test calling the turn on intent."""
with pytest.raises(HomeAssistantError), patch( with (
"homeassistant.components.conversation.async_converse", pytest.raises(HomeAssistantError),
side_effect=intent.IntentHandleError, patch(
"homeassistant.components.conversation.async_converse",
side_effect=intent.IntentHandleError,
),
): ):
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",

View file

@ -5,7 +5,8 @@ import logging
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import trigger from homeassistant.helpers import trigger
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -51,9 +52,7 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
service_response = await hass.services.async_call( service_response = await hass.services.async_call(
"conversation", "conversation",
"process", "process",
{ {"text": "Ha ha ha"},
"text": "Ha ha ha",
},
blocking=True, blocking=True,
return_response=True, return_response=True,
) )
@ -69,6 +68,7 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
"sentence": "Ha ha ha", "sentence": "Ha ha ha",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None,
} }
@ -160,6 +160,7 @@ async def test_response_same_sentence(hass: HomeAssistant, calls, setup_comp) ->
"sentence": "test sentence", "sentence": "test sentence",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None,
} }
@ -311,6 +312,7 @@ async def test_same_trigger_multiple_sentences(
"sentence": "hello", "sentence": "hello",
"slots": {}, "slots": {},
"details": {}, "details": {},
"device_id": None,
} }
@ -488,4 +490,40 @@ async def test_wildcards(hass: HomeAssistant, calls, setup_comp) -> None:
"value": "the beatles", "value": "the beatles",
}, },
}, },
"device_id": None,
} }
async def test_trigger_with_device_id(hass: HomeAssistant) -> None:
"""Test that a trigger receives a device_id."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["test sentence"],
},
"action": {
"set_conversation_response": "{{ trigger.device_id }}",
},
}
},
)
agent = await conversation._get_agent_manager(hass).async_get_agent()
assert isinstance(agent, conversation.DefaultAgent)
result = await agent.async_process(
conversation.ConversationInput(
text="test sentence",
context=Context(),
conversation_id=None,
device_id="my_device",
language=hass.config.language,
)
)
assert result.response.speech["plain"]["speech"] == "my_device"