Compare commits
8 commits
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
2e1d843ff8 | ||
|
2e2b79fd3a | ||
|
2e3489cadc | ||
|
745861cf91 | ||
|
3ce5ed63e1 | ||
|
b20e12f1c8 | ||
|
5ac873939f | ||
|
736a5d4a94 |
10 changed files with 454 additions and 53 deletions
|
@ -31,6 +31,7 @@ from homeassistant.components.tts import (
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.helpers.collection import (
|
from homeassistant.helpers.collection import (
|
||||||
CHANGE_UPDATED,
|
CHANGE_UPDATED,
|
||||||
CollectionError,
|
CollectionError,
|
||||||
|
@ -109,6 +110,7 @@ PIPELINE_FIELDS: VolDictType = {
|
||||||
vol.Required("tts_voice"): vol.Any(str, None),
|
vol.Required("tts_voice"): vol.Any(str, None),
|
||||||
vol.Required("wake_word_entity"): vol.Any(str, None),
|
vol.Required("wake_word_entity"): vol.Any(str, None),
|
||||||
vol.Required("wake_word_id"): vol.Any(str, None),
|
vol.Required("wake_word_id"): vol.Any(str, None),
|
||||||
|
vol.Optional("prefer_local_intents"): bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
STORED_PIPELINE_RUNS = 10
|
STORED_PIPELINE_RUNS = 10
|
||||||
|
@ -322,6 +324,7 @@ async def async_update_pipeline(
|
||||||
tts_voice: str | None | UndefinedType = UNDEFINED,
|
tts_voice: str | None | UndefinedType = UNDEFINED,
|
||||||
wake_word_entity: str | None | UndefinedType = UNDEFINED,
|
wake_word_entity: str | None | UndefinedType = UNDEFINED,
|
||||||
wake_word_id: str | None | UndefinedType = UNDEFINED,
|
wake_word_id: str | None | UndefinedType = UNDEFINED,
|
||||||
|
prefer_local_intents: bool | UndefinedType = UNDEFINED,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a pipeline."""
|
"""Update a pipeline."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
@ -345,6 +348,7 @@ async def async_update_pipeline(
|
||||||
("tts_voice", tts_voice),
|
("tts_voice", tts_voice),
|
||||||
("wake_word_entity", wake_word_entity),
|
("wake_word_entity", wake_word_entity),
|
||||||
("wake_word_id", wake_word_id),
|
("wake_word_id", wake_word_id),
|
||||||
|
("prefer_local_intents", prefer_local_intents),
|
||||||
)
|
)
|
||||||
if val is not UNDEFINED
|
if val is not UNDEFINED
|
||||||
}
|
}
|
||||||
|
@ -398,6 +402,7 @@ class Pipeline:
|
||||||
tts_voice: str | None
|
tts_voice: str | None
|
||||||
wake_word_entity: str | None
|
wake_word_entity: str | None
|
||||||
wake_word_id: str | None
|
wake_word_id: str | None
|
||||||
|
prefer_local_intents: bool = False
|
||||||
|
|
||||||
id: str = field(default_factory=ulid_util.ulid_now)
|
id: str = field(default_factory=ulid_util.ulid_now)
|
||||||
|
|
||||||
|
@ -421,6 +426,7 @@ class Pipeline:
|
||||||
tts_voice=data["tts_voice"],
|
tts_voice=data["tts_voice"],
|
||||||
wake_word_entity=data["wake_word_entity"],
|
wake_word_entity=data["wake_word_entity"],
|
||||||
wake_word_id=data["wake_word_id"],
|
wake_word_id=data["wake_word_id"],
|
||||||
|
prefer_local_intents=data.get("prefer_local_intents", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_json(self) -> dict[str, Any]:
|
def to_json(self) -> dict[str, Any]:
|
||||||
|
@ -438,6 +444,7 @@ class Pipeline:
|
||||||
"tts_voice": self.tts_voice,
|
"tts_voice": self.tts_voice,
|
||||||
"wake_word_entity": self.wake_word_entity,
|
"wake_word_entity": self.wake_word_entity,
|
||||||
"wake_word_id": self.wake_word_id,
|
"wake_word_id": self.wake_word_id,
|
||||||
|
"prefer_local_intents": self.prefer_local_intents,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1016,15 +1023,58 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conversation_result = await conversation.async_converse(
|
user_input = conversation.ConversationInput(
|
||||||
hass=self.hass,
|
|
||||||
text=intent_input,
|
text=intent_input,
|
||||||
|
context=self.context,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
context=self.context,
|
language=self.pipeline.language,
|
||||||
language=self.pipeline.conversation_language,
|
|
||||||
agent_id=self.intent_agent,
|
agent_id=self.intent_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sentence triggers override conversation agent
|
||||||
|
if (
|
||||||
|
trigger_response_text
|
||||||
|
:= await conversation.async_handle_sentence_triggers(
|
||||||
|
self.hass, user_input
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# Sentence trigger matched
|
||||||
|
trigger_response = intent.IntentResponse(
|
||||||
|
self.pipeline.conversation_language
|
||||||
|
)
|
||||||
|
trigger_response.async_set_speech(trigger_response_text)
|
||||||
|
conversation_result = conversation.ConversationResult(
|
||||||
|
response=trigger_response,
|
||||||
|
conversation_id=user_input.conversation_id,
|
||||||
|
)
|
||||||
|
# Try local intents first, if preferred.
|
||||||
|
# Skip this step if the default agent is already used.
|
||||||
|
elif (
|
||||||
|
self.pipeline.prefer_local_intents
|
||||||
|
and (user_input.agent_id != conversation.HOME_ASSISTANT_AGENT)
|
||||||
|
and (
|
||||||
|
intent_response := await conversation.async_handle_intents(
|
||||||
|
self.hass, user_input
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# Local intent matched
|
||||||
|
conversation_result = conversation.ConversationResult(
|
||||||
|
response=intent_response,
|
||||||
|
conversation_id=user_input.conversation_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fall back to pipeline conversation agent
|
||||||
|
conversation_result = await conversation.async_converse(
|
||||||
|
hass=self.hass,
|
||||||
|
text=user_input.text,
|
||||||
|
conversation_id=user_input.conversation_id,
|
||||||
|
device_id=user_input.device_id,
|
||||||
|
context=user_input.context,
|
||||||
|
language=user_input.language,
|
||||||
|
agent_id=user_input.agent_id,
|
||||||
|
)
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during intent recognition")
|
_LOGGER.exception("Unexpected error during intent recognition")
|
||||||
raise IntentRecognitionError(
|
raise IntentRecognitionError(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Handle Cloud assist pipelines."""
|
"""Handle Cloud assist pipelines."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
async_create_default_pipeline,
|
async_create_default_pipeline,
|
||||||
|
@ -98,7 +99,7 @@ async def async_migrate_cloud_pipeline_engine(
|
||||||
# is an after dependency of cloud
|
# is an after dependency of cloud
|
||||||
await async_setup_pipeline_store(hass)
|
await async_setup_pipeline_store(hass)
|
||||||
|
|
||||||
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
|
kwargs: dict[str, Any] = {pipeline_attribute: engine_id}
|
||||||
pipelines = async_get_pipelines(hass)
|
pipelines = async_get_pipelines(hass)
|
||||||
for pipeline in pipelines:
|
for pipeline in pipelines:
|
||||||
if getattr(pipeline, pipeline_attribute) == DOMAIN:
|
if getattr(pipeline, pipeline_attribute) == DOMAIN:
|
||||||
|
|
|
@ -44,7 +44,7 @@ from .const import (
|
||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
ConversationEntityFeature,
|
ConversationEntityFeature,
|
||||||
)
|
)
|
||||||
from .default_agent import async_setup_default_agent
|
from .default_agent import DefaultAgent, async_setup_default_agent
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
|
@ -207,6 +207,32 @@ async def async_prepare_agent(
|
||||||
await agent.async_prepare(language)
|
await agent.async_prepare(language)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_handle_sentence_triggers(
|
||||||
|
hass: HomeAssistant, user_input: ConversationInput
|
||||||
|
) -> str | None:
|
||||||
|
"""Try to match input against sentence triggers and return response text.
|
||||||
|
|
||||||
|
Returns None if no match occurred.
|
||||||
|
"""
|
||||||
|
default_agent = async_get_agent(hass)
|
||||||
|
assert isinstance(default_agent, DefaultAgent)
|
||||||
|
|
||||||
|
return await default_agent.async_handle_sentence_triggers(user_input)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_handle_intents(
|
||||||
|
hass: HomeAssistant, user_input: ConversationInput
|
||||||
|
) -> intent.IntentResponse | None:
|
||||||
|
"""Try to match input against registered intents and return response.
|
||||||
|
|
||||||
|
Returns None if no match occurred.
|
||||||
|
"""
|
||||||
|
default_agent = async_get_agent(hass)
|
||||||
|
assert isinstance(default_agent, DefaultAgent)
|
||||||
|
|
||||||
|
return await default_agent.async_handle_intents(user_input)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Register the process service."""
|
"""Register the process service."""
|
||||||
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)
|
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)
|
||||||
|
|
|
@ -214,10 +214,15 @@ class DefaultAgent(ConversationEntity):
|
||||||
]
|
]
|
||||||
|
|
||||||
async def async_recognize(
|
async def async_recognize(
|
||||||
self, user_input: ConversationInput
|
self,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
strict_intents_only: bool = False,
|
||||||
|
match_sentence_triggers: bool = True,
|
||||||
) -> RecognizeResult | SentenceTriggerResult | None:
|
) -> RecognizeResult | SentenceTriggerResult | None:
|
||||||
"""Recognize intent from user input."""
|
"""Recognize intent from user input."""
|
||||||
if trigger_result := await self._match_triggers(user_input.text):
|
if (match_sentence_triggers) and (
|
||||||
|
trigger_result := await self._match_triggers(user_input.text)
|
||||||
|
):
|
||||||
return trigger_result
|
return trigger_result
|
||||||
|
|
||||||
language = user_input.language or self.hass.config.language
|
language = user_input.language or self.hass.config.language
|
||||||
|
@ -240,6 +245,7 @@ class DefaultAgent(ConversationEntity):
|
||||||
slot_lists,
|
slot_lists,
|
||||||
intent_context,
|
intent_context,
|
||||||
language,
|
language,
|
||||||
|
strict_intents_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
|
@ -251,56 +257,33 @@ class DefaultAgent(ConversationEntity):
|
||||||
|
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
language = user_input.language or self.hass.config.language
|
|
||||||
conversation_id = None # Not supported
|
|
||||||
|
|
||||||
result = await self.async_recognize(user_input)
|
result = await self.async_recognize(user_input)
|
||||||
|
|
||||||
# Check if a trigger matched
|
# Check if a trigger matched
|
||||||
if isinstance(result, SentenceTriggerResult):
|
if isinstance(result, SentenceTriggerResult):
|
||||||
# Gather callback responses in parallel
|
# Process callbacks and get response
|
||||||
trigger_callbacks = [
|
response_text = await self._handle_trigger_result(result, user_input)
|
||||||
self._trigger_sentences[trigger_id].callback(
|
|
||||||
result.sentence, trigger_result, user_input.device_id
|
|
||||||
)
|
|
||||||
for trigger_id, trigger_result in result.matched_triggers.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
# Use first non-empty result as response.
|
|
||||||
#
|
|
||||||
# There may be multiple copies of a trigger running when editing in
|
|
||||||
# the UI, so it's critical that we filter out empty responses here.
|
|
||||||
response_text: str | None = None
|
|
||||||
response_set_by_trigger = False
|
|
||||||
for trigger_future in asyncio.as_completed(trigger_callbacks):
|
|
||||||
trigger_response = await trigger_future
|
|
||||||
if trigger_response is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
response_text = trigger_response
|
|
||||||
response_set_by_trigger = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Convert to conversation result
|
# Convert to conversation result
|
||||||
response = intent.IntentResponse(language=language)
|
response = intent.IntentResponse(
|
||||||
|
language=user_input.language or self.hass.config.language
|
||||||
|
)
|
||||||
response.response_type = intent.IntentResponseType.ACTION_DONE
|
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||||
|
|
||||||
if response_set_by_trigger:
|
|
||||||
# Response was explicitly set to empty
|
|
||||||
response_text = response_text or ""
|
|
||||||
elif not response_text:
|
|
||||||
# Use translated acknowledgment for pipeline language
|
|
||||||
translations = await translation.async_get_translations(
|
|
||||||
self.hass, language, DOMAIN, [DOMAIN]
|
|
||||||
)
|
|
||||||
response_text = translations.get(
|
|
||||||
f"component.{DOMAIN}.conversation.agent.done", "Done"
|
|
||||||
)
|
|
||||||
|
|
||||||
response.async_set_speech(response_text)
|
response.async_set_speech(response_text)
|
||||||
|
|
||||||
return ConversationResult(response=response)
|
return ConversationResult(response=response)
|
||||||
|
|
||||||
|
return await self._async_process_intent_result(result, user_input)
|
||||||
|
|
||||||
|
async def _async_process_intent_result(
|
||||||
|
self,
|
||||||
|
result: RecognizeResult | None,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
) -> ConversationResult:
|
||||||
|
"""Process user input with intents."""
|
||||||
|
language = user_input.language or self.hass.config.language
|
||||||
|
conversation_id = None # Not supported
|
||||||
|
|
||||||
# Intent match or failure
|
# Intent match or failure
|
||||||
lang_intents = await self.async_get_or_load_intents(language)
|
lang_intents = await self.async_get_or_load_intents(language)
|
||||||
|
|
||||||
|
@ -436,6 +419,7 @@ class DefaultAgent(ConversationEntity):
|
||||||
slot_lists: dict[str, SlotList],
|
slot_lists: dict[str, SlotList],
|
||||||
intent_context: dict[str, Any] | None,
|
intent_context: dict[str, Any] | None,
|
||||||
language: str,
|
language: str,
|
||||||
|
strict_intents_only: bool,
|
||||||
) -> RecognizeResult | None:
|
) -> RecognizeResult | None:
|
||||||
"""Search intents for a match to user input."""
|
"""Search intents for a match to user input."""
|
||||||
strict_result = self._recognize_strict(
|
strict_result = self._recognize_strict(
|
||||||
|
@ -446,6 +430,9 @@ class DefaultAgent(ConversationEntity):
|
||||||
# Successful strict match
|
# Successful strict match
|
||||||
return strict_result
|
return strict_result
|
||||||
|
|
||||||
|
if strict_intents_only:
|
||||||
|
return None
|
||||||
|
|
||||||
# Try again with all entities (including unexposed)
|
# Try again with all entities (including unexposed)
|
||||||
entity_registry = er.async_get(self.hass)
|
entity_registry = er.async_get(self.hass)
|
||||||
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
all_entity_names: list[tuple[str, str, dict[str, Any]]] = []
|
||||||
|
@ -1098,6 +1085,82 @@ class DefaultAgent(ConversationEntity):
|
||||||
|
|
||||||
return SentenceTriggerResult(sentence, matched_template, matched_triggers)
|
return SentenceTriggerResult(sentence, matched_template, matched_triggers)
|
||||||
|
|
||||||
|
async def _handle_trigger_result(
|
||||||
|
self, result: SentenceTriggerResult, user_input: ConversationInput
|
||||||
|
) -> str:
|
||||||
|
"""Run sentence trigger callbacks and return response text."""
|
||||||
|
|
||||||
|
# Gather callback responses in parallel
|
||||||
|
trigger_callbacks = [
|
||||||
|
self._trigger_sentences[trigger_id].callback(
|
||||||
|
user_input.text, trigger_result, user_input.device_id
|
||||||
|
)
|
||||||
|
for trigger_id, trigger_result in result.matched_triggers.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Use first non-empty result as response.
|
||||||
|
#
|
||||||
|
# There may be multiple copies of a trigger running when editing in
|
||||||
|
# the UI, so it's critical that we filter out empty responses here.
|
||||||
|
response_text = ""
|
||||||
|
response_set_by_trigger = False
|
||||||
|
for trigger_future in asyncio.as_completed(trigger_callbacks):
|
||||||
|
trigger_response = await trigger_future
|
||||||
|
if trigger_response is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_text = trigger_response
|
||||||
|
response_set_by_trigger = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if response_set_by_trigger:
|
||||||
|
# Response was explicitly set to empty
|
||||||
|
response_text = response_text or ""
|
||||||
|
elif not response_text:
|
||||||
|
# Use translated acknowledgment for pipeline language
|
||||||
|
language = user_input.language or self.hass.config.language
|
||||||
|
translations = await translation.async_get_translations(
|
||||||
|
self.hass, language, DOMAIN, [DOMAIN]
|
||||||
|
)
|
||||||
|
response_text = translations.get(
|
||||||
|
f"component.{DOMAIN}.conversation.agent.done", "Done"
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
async def async_handle_sentence_triggers(
|
||||||
|
self, user_input: ConversationInput
|
||||||
|
) -> str | None:
|
||||||
|
"""Try to input sentence against sentence triggers and return response text.
|
||||||
|
|
||||||
|
Returns None if no match occurred.
|
||||||
|
"""
|
||||||
|
if trigger_result := await self._match_triggers(user_input.text):
|
||||||
|
return await self._handle_trigger_result(trigger_result, user_input)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_handle_intents(
|
||||||
|
self,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
) -> intent.IntentResponse | None:
|
||||||
|
"""Try to match sentence against registered intents and return response.
|
||||||
|
|
||||||
|
Only performs strict matching with exposed entities and exact wording.
|
||||||
|
Returns None if no match occurred.
|
||||||
|
"""
|
||||||
|
result = await self.async_recognize(
|
||||||
|
user_input, strict_intents_only=True, match_sentence_triggers=False
|
||||||
|
)
|
||||||
|
if not isinstance(result, RecognizeResult):
|
||||||
|
# No error message on failed match
|
||||||
|
return None
|
||||||
|
|
||||||
|
conversation_result = await self._async_process_intent_result(
|
||||||
|
result, user_input
|
||||||
|
)
|
||||||
|
return conversation_result.response
|
||||||
|
|
||||||
|
|
||||||
def _make_error_result(
|
def _make_error_result(
|
||||||
language: str,
|
language: str,
|
||||||
|
@ -1108,7 +1171,6 @@ def _make_error_result(
|
||||||
"""Create conversation result with error code and text."""
|
"""Create conversation result with error code and text."""
|
||||||
response = intent.IntentResponse(language=language)
|
response = intent.IntentResponse(language=language)
|
||||||
response.async_set_error(error_code, response_text)
|
response.async_set_error(error_code, response_text)
|
||||||
|
|
||||||
return ConversationResult(response, conversation_id)
|
return ConversationResult(response, conversation_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
@ -228,7 +228,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
|
|
@ -11,13 +11,20 @@ import wave
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, media_source, stt, tts
|
from homeassistant.components import (
|
||||||
|
assist_pipeline,
|
||||||
|
conversation,
|
||||||
|
media_source,
|
||||||
|
stt,
|
||||||
|
tts,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_pipeline.const import (
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
BYTES_PER_CHUNK,
|
BYTES_PER_CHUNK,
|
||||||
CONF_DEBUG_RECORDING_DIR,
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
@ -927,3 +934,148 @@ async def test_tts_dict_preferred_format(
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sentence_trigger_overrides_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that sentence triggers are checked before the conversation agent."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": [
|
||||||
|
"test trigger sentence",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"set_conversation_response": "test trigger response",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test trigger sentence",
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Sentence trigger should have been handled
|
||||||
|
mock_async_converse.assert_not_called()
|
||||||
|
|
||||||
|
# Verify sentence trigger response
|
||||||
|
intent_end_event = next(
|
||||||
|
(
|
||||||
|
e
|
||||||
|
for e in events
|
||||||
|
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert (intent_end_event is not None) and intent_end_event.data
|
||||||
|
assert (
|
||||||
|
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||||
|
"speech"
|
||||||
|
]
|
||||||
|
== "test trigger response"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prefer_local_intents(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the default agent is checked first when local intents are preferred."""
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
# Reuse custom sentences in test config
|
||||||
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
|
intent_type = "OrderBeer"
|
||||||
|
|
||||||
|
async def async_handle(
|
||||||
|
self, intent_obj: intent.Intent
|
||||||
|
) -> intent.IntentResponse:
|
||||||
|
response = intent_obj.create_response()
|
||||||
|
response.async_set_speech("Order confirmed")
|
||||||
|
return response
|
||||||
|
|
||||||
|
handler = OrderBeerIntentHandler()
|
||||||
|
intent.async_register(hass, handler)
|
||||||
|
|
||||||
|
# Fake a test agent and prefer local intents
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
await assist_pipeline.pipeline.async_update_pipeline(
|
||||||
|
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="I'd like to order a stout please",
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure prepare succeeds
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
):
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Test agent should not have been called
|
||||||
|
mock_async_converse.assert_not_called()
|
||||||
|
|
||||||
|
# Verify local intent response
|
||||||
|
intent_end_event = next(
|
||||||
|
(
|
||||||
|
e
|
||||||
|
for e in events
|
||||||
|
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert (intent_end_event is not None) and intent_end_event.data
|
||||||
|
assert (
|
||||||
|
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||||
|
"speech"
|
||||||
|
]
|
||||||
|
== "Order confirmed"
|
||||||
|
)
|
||||||
|
|
|
@ -574,6 +574,7 @@ async def test_update_pipeline(
|
||||||
"tts_voice": "test_voice",
|
"tts_voice": "test_voice",
|
||||||
"wake_word_entity": "wake_work.test_1",
|
"wake_word_entity": "wake_work.test_1",
|
||||||
"wake_word_id": "wake_word_id_1",
|
"wake_word_id": "wake_word_id_1",
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
await async_update_pipeline(
|
await async_update_pipeline(
|
||||||
|
@ -617,6 +618,7 @@ async def test_update_pipeline(
|
||||||
"tts_voice": "test_voice",
|
"tts_voice": "test_voice",
|
||||||
"wake_word_entity": "wake_work.test_1",
|
"wake_word_entity": "wake_work.test_1",
|
||||||
"wake_word_id": "wake_word_id_1",
|
"wake_word_id": "wake_word_id_1",
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -974,6 +974,7 @@ async def test_add_pipeline(
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
"wake_word_entity": "wakeword_entity_1",
|
"wake_word_entity": "wakeword_entity_1",
|
||||||
"wake_word_id": "wakeword_id_1",
|
"wake_word_id": "wakeword_id_1",
|
||||||
|
"prefer_local_intents": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
@ -991,6 +992,7 @@ async def test_add_pipeline(
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
"wake_word_entity": "wakeword_entity_1",
|
"wake_word_entity": "wakeword_entity_1",
|
||||||
"wake_word_id": "wakeword_id_1",
|
"wake_word_id": "wakeword_id_1",
|
||||||
|
"prefer_local_intents": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 2
|
assert len(pipeline_store.data) == 2
|
||||||
|
@ -1008,6 +1010,7 @@ async def test_add_pipeline(
|
||||||
tts_voice="Arnold Schwarzenegger",
|
tts_voice="Arnold Schwarzenegger",
|
||||||
wake_word_entity="wakeword_entity_1",
|
wake_word_entity="wakeword_entity_1",
|
||||||
wake_word_id="wakeword_id_1",
|
wake_word_id="wakeword_id_1",
|
||||||
|
prefer_local_intents=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
|
@ -1195,6 +1198,7 @@ async def test_get_pipeline(
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": "james_earl_jones",
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get conversation agent as pipeline
|
# Get conversation agent as pipeline
|
||||||
|
@ -1220,6 +1224,7 @@ async def test_get_pipeline(
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": "james_earl_jones",
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
|
@ -1249,6 +1254,7 @@ async def test_get_pipeline(
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
"wake_word_entity": "wakeword_entity_1",
|
"wake_word_entity": "wakeword_entity_1",
|
||||||
"wake_word_id": "wakeword_id_1",
|
"wake_word_id": "wakeword_id_1",
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
@ -1277,6 +1283,7 @@ async def test_get_pipeline(
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
"wake_word_entity": "wakeword_entity_1",
|
"wake_word_entity": "wakeword_entity_1",
|
||||||
"wake_word_id": "wakeword_id_1",
|
"wake_word_id": "wakeword_id_1",
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1304,6 +1311,7 @@ async def test_list_pipelines(
|
||||||
"tts_voice": "james_earl_jones",
|
"tts_voice": "james_earl_jones",
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"preferred_pipeline": ANY,
|
"preferred_pipeline": ANY,
|
||||||
|
@ -1395,6 +1403,7 @@ async def test_update_pipeline(
|
||||||
"tts_voice": "new_tts_voice",
|
"tts_voice": "new_tts_voice",
|
||||||
"wake_word_entity": "new_wakeword_entity",
|
"wake_word_entity": "new_wakeword_entity",
|
||||||
"wake_word_id": "new_wakeword_id",
|
"wake_word_id": "new_wakeword_id",
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 2
|
assert len(pipeline_store.data) == 2
|
||||||
|
@ -1446,6 +1455,7 @@ async def test_update_pipeline(
|
||||||
"tts_voice": None,
|
"tts_voice": None,
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
pipeline = pipeline_store.data[pipeline_id]
|
pipeline = pipeline_store.data[pipeline_id]
|
||||||
|
|
|
@ -35,6 +35,7 @@ PIPELINE_DATA = {
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_2",
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
@ -49,6 +50,7 @@ PIPELINE_DATA = {
|
||||||
"tts_voice": "The Voice",
|
"tts_voice": "The Voice",
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_3",
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
@ -63,6 +65,7 @@ PIPELINE_DATA = {
|
||||||
"tts_voice": None,
|
"tts_voice": None,
|
||||||
"wake_word_entity": None,
|
"wake_word_entity": None,
|
||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
|
"prefer_local_intents": False,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
|
|
|
@ -8,10 +8,15 @@ from syrupy.assertion import SnapshotAssertion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import default_agent
|
from homeassistant.components.conversation import (
|
||||||
|
ConversationInput,
|
||||||
|
async_handle_intents,
|
||||||
|
async_handle_sentence_triggers,
|
||||||
|
default_agent,
|
||||||
|
)
|
||||||
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
|
||||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
@ -229,3 +234,93 @@ async def test_prepare_agent(
|
||||||
await conversation.async_prepare_agent(hass, agent_id, "en")
|
await conversation.async_prepare_agent(hass, agent_id, "en")
|
||||||
|
|
||||||
assert len(mock_prepare.mock_calls) == 1
|
assert len(mock_prepare.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_handle_sentence_triggers(hass: HomeAssistant) -> None:
|
||||||
|
"""Test handling sentence triggers with async_handle_sentence_triggers."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
response_template = "response {{ trigger.device_id }}"
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": ["my trigger"],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"set_conversation_response": response_template,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Device id will be available in response template
|
||||||
|
device_id = "1234"
|
||||||
|
expected_response = f"response {device_id}"
|
||||||
|
actual_response = await async_handle_sentence_triggers(
|
||||||
|
hass,
|
||||||
|
ConversationInput(
|
||||||
|
text="my trigger",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=device_id,
|
||||||
|
language=hass.config.language,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert actual_response == expected_response
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_handle_intents(hass: HomeAssistant) -> None:
|
||||||
|
"""Test handling registered intents with async_handle_intents."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
# Reuse custom sentences in test config to trigger default agent.
|
||||||
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
|
intent_type = "OrderBeer"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.was_handled = False
|
||||||
|
|
||||||
|
async def async_handle(
|
||||||
|
self, intent_obj: intent.Intent
|
||||||
|
) -> intent.IntentResponse:
|
||||||
|
self.was_handled = True
|
||||||
|
return intent_obj.create_response()
|
||||||
|
|
||||||
|
handler = OrderBeerIntentHandler()
|
||||||
|
intent.async_register(hass, handler)
|
||||||
|
|
||||||
|
# Registered intent will be handled
|
||||||
|
result = await async_handle_intents(
|
||||||
|
hass,
|
||||||
|
ConversationInput(
|
||||||
|
text="I'd like to order a stout",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.intent is not None
|
||||||
|
assert result.intent.intent_type == handler.intent_type
|
||||||
|
assert handler.was_handled
|
||||||
|
|
||||||
|
# No error messages, just None as a result
|
||||||
|
result = await async_handle_intents(
|
||||||
|
hass,
|
||||||
|
ConversationInput(
|
||||||
|
text="this sentence does not exist",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
Loading…
Add table
Reference in a new issue