Compare commits

...
Sign in to create a new pull request.

8 commits

Author SHA1 Message Date
Michael Hansen
2e1d843ff8 Clean up and fix translation key 2024-11-14 09:35:18 -06:00
Michael Hansen
2e2b79fd3a Fix cloud test 2024-11-14 08:51:47 -06:00
Michael Hansen
2e3489cadc Use pipeline language 2024-11-14 08:51:47 -06:00
Michael Hansen
745861cf91 Fix type again 2024-11-14 08:51:47 -06:00
Michael Hansen
3ce5ed63e1 Fix type 2024-11-14 08:51:47 -06:00
Michael Hansen
b20e12f1c8 Check sentence triggers and local intents first 2024-11-14 08:51:47 -06:00
Michael Hansen
5ac873939f Remove from LLM 2024-11-14 08:51:47 -06:00
Michael Hansen
736a5d4a94 Handle sentence triggers and registered intents in Assist LLM API 2024-11-14 08:51:44 -06:00
10 changed files with 454 additions and 53 deletions

View file

@ -31,6 +31,7 @@ from homeassistant.components.tts import (
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.helpers.collection import (
CHANGE_UPDATED,
CollectionError,
@ -109,6 +110,7 @@ PIPELINE_FIELDS: VolDictType = {
vol.Required("tts_voice"): vol.Any(str, None),
vol.Required("wake_word_entity"): vol.Any(str, None),
vol.Required("wake_word_id"): vol.Any(str, None),
vol.Optional("prefer_local_intents"): bool,
}
STORED_PIPELINE_RUNS = 10
@ -322,6 +324,7 @@ async def async_update_pipeline(
tts_voice: str | None | UndefinedType = UNDEFINED,
wake_word_entity: str | None | UndefinedType = UNDEFINED,
wake_word_id: str | None | UndefinedType = UNDEFINED,
prefer_local_intents: bool | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
@ -345,6 +348,7 @@ async def async_update_pipeline(
("tts_voice", tts_voice),
("wake_word_entity", wake_word_entity),
("wake_word_id", wake_word_id),
("prefer_local_intents", prefer_local_intents),
)
if val is not UNDEFINED
}
@ -398,6 +402,7 @@ class Pipeline:
tts_voice: str | None
wake_word_entity: str | None
wake_word_id: str | None
prefer_local_intents: bool = False
id: str = field(default_factory=ulid_util.ulid_now)
@ -421,6 +426,7 @@ class Pipeline:
tts_voice=data["tts_voice"],
wake_word_entity=data["wake_word_entity"],
wake_word_id=data["wake_word_id"],
prefer_local_intents=data.get("prefer_local_intents", False),
)
def to_json(self) -> dict[str, Any]:
@ -438,6 +444,7 @@ class Pipeline:
"tts_voice": self.tts_voice,
"wake_word_entity": self.wake_word_entity,
"wake_word_id": self.wake_word_id,
"prefer_local_intents": self.prefer_local_intents,
}
@ -1016,15 +1023,58 @@ class PipelineRun:
)
try:
conversation_result = await conversation.async_converse(
hass=self.hass,
user_input = conversation.ConversationInput(
text=intent_input,
context=self.context,
conversation_id=conversation_id,
device_id=device_id,
context=self.context,
language=self.pipeline.conversation_language,
language=self.pipeline.language,
agent_id=self.intent_agent,
)
# Sentence triggers override conversation agent
if (
trigger_response_text
:= await conversation.async_handle_sentence_triggers(
self.hass, user_input
)
):
# Sentence trigger matched
trigger_response = intent.IntentResponse(
self.pipeline.conversation_language
)
trigger_response.async_set_speech(trigger_response_text)
conversation_result = conversation.ConversationResult(
response=trigger_response,
conversation_id=user_input.conversation_id,
)
# Try local intents first, if preferred.
# Skip this step if the default agent is already used.
elif (
self.pipeline.prefer_local_intents
and (user_input.agent_id != conversation.HOME_ASSISTANT_AGENT)
and (
intent_response := await conversation.async_handle_intents(
self.hass, user_input
)
)
):
# Local intent matched
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
else:
# Fall back to pipeline conversation agent
conversation_result = await conversation.async_converse(
hass=self.hass,
text=user_input.text,
conversation_id=user_input.conversation_id,
device_id=user_input.device_id,
context=user_input.context,
language=user_input.language,
agent_id=user_input.agent_id,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError(

View file

@ -1,6 +1,7 @@
"""Handle Cloud assist pipelines."""
import asyncio
from typing import Any
from homeassistant.components.assist_pipeline import (
async_create_default_pipeline,
@ -98,7 +99,7 @@ async def async_migrate_cloud_pipeline_engine(
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
kwargs: dict[str, Any] = {pipeline_attribute: engine_id}
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if getattr(pipeline, pipeline_attribute) == DOMAIN:

View file

@ -44,7 +44,7 @@ from .const import (
SERVICE_RELOAD,
ConversationEntityFeature,
)
from .default_agent import async_setup_default_agent
from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
@ -207,6 +207,32 @@ async def async_prepare_agent(
await agent.async_prepare(language)
async def async_handle_sentence_triggers(
    hass: HomeAssistant, user_input: ConversationInput
) -> str | None:
    """Match the input against sentence triggers and return the response text.

    The work is delegated to the agent obtained from ``async_get_agent(hass)``,
    which must be a ``DefaultAgent``.

    Returns None when no sentence trigger matched.
    """
    agent = async_get_agent(hass)
    # Narrow the type: the sentence-trigger helper only exists on DefaultAgent.
    assert isinstance(agent, DefaultAgent)

    response_text = await agent.async_handle_sentence_triggers(user_input)
    return response_text
async def async_handle_intents(
    hass: HomeAssistant, user_input: ConversationInput
) -> intent.IntentResponse | None:
    """Match the input against registered intents and return the response.

    The work is delegated to the agent obtained from ``async_get_agent(hass)``,
    which must be a ``DefaultAgent``.

    Returns None when no intent matched.
    """
    agent = async_get_agent(hass)
    # Narrow the type: the intent helper only exists on DefaultAgent.
    assert isinstance(agent, DefaultAgent)

    intent_response = await agent.async_handle_intents(user_input)
    return intent_response
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)

View file

@ -214,10 +214,15 @@ class DefaultAgent(ConversationEntity):
]
async def async_recognize(
self, user_input: ConversationInput
self,
user_input: ConversationInput,
strict_intents_only: bool = False,
match_sentence_triggers: bool = True,
) -> RecognizeResult | SentenceTriggerResult | None:
"""Recognize intent from user input."""
if trigger_result := await self._match_triggers(user_input.text):
if (match_sentence_triggers) and (
trigger_result := await self._match_triggers(user_input.text)
):
return trigger_result
language = user_input.language or self.hass.config.language
@ -240,6 +245,7 @@ class DefaultAgent(ConversationEntity):
slot_lists,
intent_context,
language,
strict_intents_only,
)
_LOGGER.debug(
@ -251,56 +257,33 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
language = user_input.language or self.hass.config.language
conversation_id = None # Not supported
result = await self.async_recognize(user_input)
# Check if a trigger matched
if isinstance(result, SentenceTriggerResult):
# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
result.sentence, trigger_result, user_input.device_id
)
for trigger_id, trigger_result in result.matched_triggers.items()
]
# Use first non-empty result as response.
#
# There may be multiple copies of a trigger running when editing in
# the UI, so it's critical that we filter out empty responses here.
response_text: str | None = None
response_set_by_trigger = False
for trigger_future in asyncio.as_completed(trigger_callbacks):
trigger_response = await trigger_future
if trigger_response is None:
continue
response_text = trigger_response
response_set_by_trigger = True
break
# Process callbacks and get response
response_text = await self._handle_trigger_result(result, user_input)
# Convert to conversation result
response = intent.IntentResponse(language=language)
response = intent.IntentResponse(
language=user_input.language or self.hass.config.language
)
response.response_type = intent.IntentResponseType.ACTION_DONE
if response_set_by_trigger:
# Response was explicitly set to empty
response_text = response_text or ""
elif not response_text:
# Use translated acknowledgment for pipeline language
translations = await translation.async_get_translations(
self.hass, language, DOMAIN, [DOMAIN]
)
response_text = translations.get(
f"component.{DOMAIN}.conversation.agent.done", "Done"
)
response.async_set_speech(response_text)
return ConversationResult(response=response)
return await self._async_process_intent_result(result, user_input)
async def _async_process_intent_result(
self,
result: RecognizeResult | None,
user_input: ConversationInput,
) -> ConversationResult:
"""Process user input with intents."""
language = user_input.language or self.hass.config.language
conversation_id = None # Not supported
# Intent match or failure
lang_intents = await self.async_get_or_load_intents(language)
@ -436,6 +419,7 @@ class DefaultAgent(ConversationEntity):
slot_lists: dict[str, SlotList],
intent_context: dict[str, Any] | None,
language: str,
strict_intents_only: bool,
) -> RecognizeResult | None:
"""Search intents for a match to user input."""
strict_result = self._recognize_strict(
@ -446,6 +430,9 @@ class DefaultAgent(ConversationEntity):
# Successful strict match
return strict_result
if strict_intents_only:
return None
# Try again with all entities (including unexposed)
entity_registry = er.async_get(self.hass)
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
@ -1098,6 +1085,82 @@ class DefaultAgent(ConversationEntity):
return SentenceTriggerResult(sentence, matched_template, matched_triggers)
async def _handle_trigger_result(
    self, result: SentenceTriggerResult, user_input: ConversationInput
) -> str:
    """Invoke the callbacks of all matched sentence triggers and build a reply.

    The reply is the response of the first callback to complete with a value
    other than None; an empty string is treated as a deliberate (empty)
    response. If no callback supplies a response, the translated "done"
    acknowledgment is returned, falling back to the literal "Done".
    """
    # One coroutine per matched trigger; they are awaited concurrently below.
    pending_callbacks = [
        self._trigger_sentences[trigger_id].callback(
            user_input.text, trigger_result, user_input.device_id
        )
        for trigger_id, trigger_result in result.matched_triggers.items()
    ]

    # There may be multiple copies of a trigger running when editing in the
    # UI, so callbacks that yield no response (None) must be skipped here.
    for finished_callback in asyncio.as_completed(pending_callbacks):
        callback_response = await finished_callback
        if callback_response is not None:
            # Response was explicitly set by the trigger (possibly to empty)
            return callback_response or ""

    # No trigger set a response: use the translated acknowledgment for the
    # input language (or the configured default language).
    language = user_input.language or self.hass.config.language
    translations = await translation.async_get_translations(
        self.hass, language, DOMAIN, [DOMAIN]
    )
    return translations.get(f"component.{DOMAIN}.conversation.agent.done", "Done")
async def async_handle_sentence_triggers(
    self, user_input: ConversationInput
) -> str | None:
    """Try to match the input sentence against sentence triggers.

    On a match, the trigger callbacks are run and their response text is
    returned.

    Returns None if no match occurred.
    """
    trigger_result = await self._match_triggers(user_input.text)
    if not trigger_result:
        return None

    return await self._handle_trigger_result(trigger_result, user_input)
async def async_handle_intents(
    self,
    user_input: ConversationInput,
) -> intent.IntentResponse | None:
    """Try to match the sentence against registered intents and respond.

    Recognition runs with strict intent matching only and with sentence
    triggers disabled (exposed entities and exact wording).

    Returns None if no match occurred; no error response is produced for a
    failed match.
    """
    recognized = await self.async_recognize(
        user_input, strict_intents_only=True, match_sentence_triggers=False
    )

    if isinstance(recognized, RecognizeResult):
        processed = await self._async_process_intent_result(recognized, user_input)
        return processed.response

    # No error message on failed match
    return None
def _make_error_result(
language: str,
@ -1108,7 +1171,6 @@ def _make_error_result(
"""Create conversation result with error code and text."""
response = intent.IntentResponse(language=language)
response.async_set_error(error_code, response_text)
return ConversationResult(response, conversation_id)

View file

@ -139,7 +139,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@ -228,7 +228,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({

View file

@ -11,13 +11,20 @@ import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, media_source, stt, tts
from homeassistant.components import (
assist_pipeline,
conversation,
media_source,
stt,
tts,
)
from homeassistant.components.assist_pipeline.const import (
BYTES_PER_CHUNK,
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from .conftest import (
@ -927,3 +934,148 @@ async def test_tts_dict_preferred_format(
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_sentence_trigger_overrides_conversation_agent(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a matching sentence trigger answers before the conversation agent."""
    # Automation whose conversation trigger sets an explicit response.
    automation_config = {
        "automation": {
            "trigger": {
                "platform": "conversation",
                "command": [
                    "test trigger sentence",
                ],
            },
            "action": {
                "set_conversation_response": "test trigger response",
            },
        }
    }
    assert await async_setup_component(hass, "automation", automation_config)

    events: list[assist_pipeline.PipelineEvent] = []

    # Run only the intent stage of the preferred pipeline.
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)
    intent_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=events.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test trigger sentence",
        run=intent_only_run,
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # The sentence trigger handled the input, so the agent was never asked.
        mock_async_converse.assert_not_called()

        # Locate the first intent-end event, if any.
        intent_end_event = None
        for event in events:
            if event.type == assist_pipeline.PipelineEventType.INTENT_END:
                intent_end_event = event
                break

        assert intent_end_event is not None
        assert intent_end_event.data

        # The spoken response must be the one set by the automation.
        plain_speech = intent_end_event.data["intent_output"]["response"]["speech"][
            "plain"
        ]
        assert plain_speech["speech"] == "test trigger response"
async def test_prefer_local_intents(
    hass: HomeAssistant,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that local intents are tried first when the pipeline prefers them."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Handler for the OrderBeer intent; reuses custom sentences in test config.
    class OrderBeerIntentHandler(intent.IntentHandler):
        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            confirmation = intent_obj.create_response()
            confirmation.async_set_speech("Order confirmed")
            return confirmation

    intent.async_register(hass, OrderBeerIntentHandler())

    # Point the preferred pipeline at a fake agent and prefer local intents.
    pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
    original_pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass,
        original_pipeline,
        conversation_engine="test-agent",
        prefer_local_intents=True,
    )

    # Re-fetch the pipeline after the update before running it.
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    intent_only_run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.INTENT,
        end_stage=assist_pipeline.PipelineStage.INTENT,
        event_callback=events.append,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="I'd like to order a stout please",
        run=intent_only_run,
    )

    # Ensure prepare succeeds even though "test-agent" is not a real agent.
    fake_agent_info = conversation.AgentInfo(id="test-agent", name="Test Agent")
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=fake_agent_info,
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # The local intent matched, so the test agent was never called.
        mock_async_converse.assert_not_called()

        # Locate the first intent-end event, if any.
        intent_end_event = None
        for event in events:
            if event.type == assist_pipeline.PipelineEventType.INTENT_END:
                intent_end_event = event
                break

        assert intent_end_event is not None
        assert intent_end_event.data

        # The spoken response must come from the local intent handler.
        plain_speech = intent_end_event.data["intent_output"]["response"]["speech"][
            "plain"
        ]
        assert plain_speech["speech"] == "Order confirmed"

View file

@ -574,6 +574,7 @@ async def test_update_pipeline(
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
"prefer_local_intents": False,
}
await async_update_pipeline(
@ -617,6 +618,7 @@ async def test_update_pipeline(
"tts_voice": "test_voice",
"wake_word_entity": "wake_work.test_1",
"wake_word_id": "wake_word_id_1",
"prefer_local_intents": False,
}

View file

@ -974,6 +974,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": True,
}
)
msg = await client.receive_json()
@ -991,6 +992,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": True,
}
assert len(pipeline_store.data) == 2
@ -1008,6 +1010,7 @@ async def test_add_pipeline(
tts_voice="Arnold Schwarzenegger",
wake_word_entity="wakeword_entity_1",
wake_word_id="wakeword_id_1",
prefer_local_intents=True,
)
await client.send_json_auto_id(
@ -1195,6 +1198,7 @@ async def test_get_pipeline(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
# Get conversation agent as pipeline
@ -1220,6 +1224,7 @@ async def test_get_pipeline(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
await client.send_json_auto_id(
@ -1249,6 +1254,7 @@ async def test_get_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": False,
}
)
msg = await client.receive_json()
@ -1277,6 +1283,7 @@ async def test_get_pipeline(
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": "wakeword_entity_1",
"wake_word_id": "wakeword_id_1",
"prefer_local_intents": False,
}
@ -1304,6 +1311,7 @@ async def test_list_pipelines(
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
],
"preferred_pipeline": ANY,
@ -1395,6 +1403,7 @@ async def test_update_pipeline(
"tts_voice": "new_tts_voice",
"wake_word_entity": "new_wakeword_entity",
"wake_word_id": "new_wakeword_id",
"prefer_local_intents": False,
}
assert len(pipeline_store.data) == 2
@ -1446,6 +1455,7 @@ async def test_update_pipeline(
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
}
pipeline = pipeline_store.data[pipeline_id]

View file

@ -35,6 +35,7 @@ PIPELINE_DATA = {
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
{
"conversation_engine": "conversation_engine_2",
@ -49,6 +50,7 @@ PIPELINE_DATA = {
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
{
"conversation_engine": "conversation_engine_3",
@ -63,6 +65,7 @@ PIPELINE_DATA = {
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
"prefer_local_intents": False,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",

View file

@ -8,10 +8,15 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import default_agent
from homeassistant.components.conversation import (
ConversationInput,
async_handle_intents,
async_handle_sentence_triggers,
default_agent,
)
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
@ -229,3 +234,93 @@ async def test_prepare_agent(
await conversation.async_prepare_agent(hass, agent_id, "en")
assert len(mock_prepare.mock_calls) == 1
async def test_async_handle_sentence_triggers(hass: HomeAssistant) -> None:
    """Test that async_handle_sentence_triggers returns the trigger's response."""
    for component in ("homeassistant", "conversation"):
        assert await async_setup_component(hass, component, {})

    # The response template echoes the device id passed with the input.
    automation_config = {
        "automation": {
            "trigger": {
                "platform": "conversation",
                "command": ["my trigger"],
            },
            "action": {
                "set_conversation_response": "response {{ trigger.device_id }}",
            },
        }
    }
    assert await async_setup_component(hass, "automation", automation_config)

    # Device id will be available in response template
    device_id = "1234"
    user_input = ConversationInput(
        text="my trigger",
        context=Context(),
        conversation_id=None,
        device_id=device_id,
        language=hass.config.language,
    )

    actual_response = await async_handle_sentence_triggers(hass, user_input)
    assert actual_response == f"response {device_id}"
async def test_async_handle_intents(hass: HomeAssistant) -> None:
    """Test that async_handle_intents handles registered intents only."""
    for component in ("homeassistant", "conversation"):
        assert await async_setup_component(hass, component, {})

    # Reuse custom sentences in test config to trigger default agent.
    class OrderBeerIntentHandler(intent.IntentHandler):
        intent_type = "OrderBeer"

        def __init__(self) -> None:
            super().__init__()
            # Flipped to True once async_handle has been invoked.
            self.was_handled = False

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            self.was_handled = True
            return intent_obj.create_response()

    handler = OrderBeerIntentHandler()
    intent.async_register(hass, handler)

    def make_input(sentence: str) -> ConversationInput:
        """Build a conversation input for the given sentence."""
        return ConversationInput(
            text=sentence,
            context=Context(),
            conversation_id=None,
            device_id=None,
            language=hass.config.language,
        )

    # Registered intent will be handled
    matched = await async_handle_intents(hass, make_input("I'd like to order a stout"))
    assert matched is not None
    assert matched.intent is not None
    assert matched.intent.intent_type == handler.intent_type
    assert handler.was_handled

    # No error messages, just None as a result
    unmatched = await async_handle_intents(
        hass, make_input("this sentence does not exist")
    )
    assert unmatched is None