Add set_conversation_response script action (#108233)

* Add set_conversation_response script action

* Update homeassistant/components/conversation/trigger.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Revert accidental change

* Add test

* Ignore mypy

* Remove incorrect callback decorator

* Update homeassistant/helpers/script.py

* Add test with templated set_conversation_response

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Erik Montnemery 2024-01-23 15:13:42 +01:00 committed by GitHub
parent e3a73c12bc
commit 9bff039d17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 248 additions and 13 deletions

View file

@ -72,6 +72,7 @@ from homeassistant.helpers.script import (
CONF_MAX,
CONF_MAX_EXCEEDED,
Script,
ScriptRunResult,
script_stack_cv,
)
from homeassistant.helpers.script_variables import ScriptVariables
@ -359,7 +360,7 @@ class BaseAutomationEntity(ToggleEntity, ABC):
run_variables: dict[str, Any],
context: Context | None = None,
skip_condition: bool = False,
) -> None:
) -> ScriptRunResult | None:
"""Trigger automation."""
@ -581,7 +582,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity):
run_variables: dict[str, Any],
context: Context | None = None,
skip_condition: bool = False,
) -> None:
) -> ScriptRunResult | None:
"""Trigger automation.
This method is a coroutine.
@ -617,7 +618,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity):
except TemplateError as err:
self._logger.error("Error rendering variables: %s", err)
automation_trace.set_error(err)
return
return None
# Prepare tracing the automation
automation_trace.set_trace(trace_get())
@ -644,7 +645,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity):
trace_get(clear=False),
)
script_execution_set("failed_conditions")
return
return None
self.async_set_context(trigger_context)
event_data = {
@ -666,7 +667,7 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity):
try:
with trace_path("action"):
await self.action_script.async_run(
return await self.action_script.async_run(
variables, trigger_context, started_action
)
except ServiceNotFound as err:
@ -697,6 +698,8 @@ class AutomationEntity(BaseAutomationEntity, RestoreEntity):
self._logger.exception("While executing automation %s", self.entity_id)
automation_trace.set_error(err)
return None
async def async_will_remove_from_hass(self) -> None:
"""Remove listeners when removing automation from Home Assistant."""
await super().async_will_remove_from_hass()

View file

@ -7,10 +7,11 @@ from hassil.recognize import PUNCTUATION, RecognizeResult
import voluptuous as vol
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from . import HOME_ASSISTANT_AGENT, _get_agent_manager
from .const import DOMAIN
@ -60,7 +61,6 @@ async def async_attach_trigger(
job = HassJob(action)
@callback
async def call_action(sentence: str, result: RecognizeResult) -> str | None:
"""Call action with right context."""
@ -91,7 +91,12 @@ async def async_attach_trigger(
job,
{"trigger": trigger_input},
):
await future
automation_result = await future
if isinstance(
automation_result, ScriptRunResult
) and automation_result.conversation_response not in (None, UNDEFINED):
# mypy does not understand the type narrowing, unclear why
return automation_result.conversation_response # type: ignore[return-value]
return "Done"

View file

@ -251,6 +251,7 @@ CONF_SERVICE: Final = "service"
CONF_SERVICE_DATA: Final = "data"
CONF_SERVICE_DATA_TEMPLATE: Final = "data_template"
CONF_SERVICE_TEMPLATE: Final = "service_template"
CONF_SET_CONVERSATION_RESPONSE: Final = "set_conversation_response"
CONF_SHOW_ON_MAP: Final = "show_on_map"
CONF_SLAVE: Final = "slave"
CONF_SOURCE: Final = "source"

View file

@ -67,6 +67,7 @@ from homeassistant.const import (
CONF_SERVICE_DATA,
CONF_SERVICE_DATA_TEMPLATE,
CONF_SERVICE_TEMPLATE,
CONF_SET_CONVERSATION_RESPONSE,
CONF_STATE,
CONF_STOP,
CONF_TARGET,
@ -1267,6 +1268,9 @@ def make_entity_service_schema(
)
SCRIPT_CONVERSATION_RESPONSE_SCHEMA = vol.Any(template, None)
SCRIPT_VARIABLES_SCHEMA = vol.All(
vol.Schema({str: template_complex}),
# pylint: disable-next=unnecessary-lambda
@ -1742,6 +1746,15 @@ _SCRIPT_SET_SCHEMA = vol.Schema(
}
)
_SCRIPT_SET_CONVERSATION_RESPONSE_SCHEMA = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(
CONF_SET_CONVERSATION_RESPONSE
): SCRIPT_CONVERSATION_RESPONSE_SCHEMA,
}
)
_SCRIPT_STOP_SCHEMA = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
@ -1794,6 +1807,7 @@ SCRIPT_ACTION_VARIABLES = "variables"
SCRIPT_ACTION_STOP = "stop"
SCRIPT_ACTION_IF = "if"
SCRIPT_ACTION_PARALLEL = "parallel"
SCRIPT_ACTION_SET_CONVERSATION_RESPONSE = "set_conversation_response"
def determine_script_action(action: dict[str, Any]) -> str:
@ -1840,6 +1854,9 @@ def determine_script_action(action: dict[str, Any]) -> str:
if CONF_PARALLEL in action:
return SCRIPT_ACTION_PARALLEL
if CONF_SET_CONVERSATION_RESPONSE in action:
return SCRIPT_ACTION_SET_CONVERSATION_RESPONSE
raise ValueError("Unable to determine action")
@ -1858,6 +1875,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA,
SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA,
SCRIPT_ACTION_PARALLEL: _SCRIPT_PARALLEL_SCHEMA,
SCRIPT_ACTION_SET_CONVERSATION_RESPONSE: _SCRIPT_SET_CONVERSATION_RESPONSE_SCHEMA,
}

View file

@ -52,6 +52,7 @@ from homeassistant.const import (
CONF_SERVICE,
CONF_SERVICE_DATA,
CONF_SERVICE_DATA_TEMPLATE,
CONF_SET_CONVERSATION_RESPONSE,
CONF_STOP,
CONF_TARGET,
CONF_THEN,
@ -98,7 +99,7 @@ from .trace import (
trace_update_result,
)
from .trigger import async_initialize_triggers, async_validate_trigger_config
from .typing import ConfigType
from .typing import UNDEFINED, ConfigType, UndefinedType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
@ -259,6 +260,7 @@ STATIC_VALIDATION_ACTION_TYPES = (
cv.SCRIPT_ACTION_ACTIVATE_SCENE,
cv.SCRIPT_ACTION_VARIABLES,
cv.SCRIPT_ACTION_STOP,
cv.SCRIPT_ACTION_SET_CONVERSATION_RESPONSE,
)
@ -385,6 +387,7 @@ class _ScriptRun:
self._step = -1
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
self._conversation_response: str | None | UndefinedType = UNDEFINED
def _changed(self) -> None:
if not self._stop.is_set():
@ -450,7 +453,7 @@ class _ScriptRun:
script_stack.pop()
self._finish()
return ScriptRunResult(response, self._variables)
return ScriptRunResult(self._conversation_response, response, self._variables)
async def _async_step(self, log_exceptions: bool) -> None:
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
@ -1031,6 +1034,18 @@ class _ScriptRun:
self._hass, self._variables, render_as_defaults=False
)
async def _async_set_conversation_response_step(self):
"""Set conversation response."""
self._step_log("setting conversation response")
resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE]
if resp is None:
self._conversation_response = None
else:
self._conversation_response = resp.async_render(
variables=self._variables, parse_result=False
)
trace_set_result(conversation_response=self._conversation_response)
async def _async_stop_step(self):
"""Stop script execution."""
stop = self._action[CONF_STOP]
@ -1075,11 +1090,13 @@ class _ScriptRun:
async def _async_run_script(self, script: Script) -> None:
"""Execute a script."""
await self._async_run_long_action(
result = await self._async_run_long_action(
self._hass.async_create_task(
script.async_run(self._variables, self._context)
)
)
if result and result.conversation_response is not UNDEFINED:
self._conversation_response = result.conversation_response
class _QueuedScriptRun(_ScriptRun):
@ -1202,6 +1219,7 @@ class _IfData(TypedDict):
class ScriptRunResult:
"""Container with the result of a script run."""
conversation_response: str | None | UndefinedType
service_response: ServiceResponse
variables: dict

View file

@ -73,7 +73,7 @@ class TriggerActionType(Protocol):
self,
run_variables: dict[str, Any],
context: Context | None = None,
) -> None:
) -> Any:
"""Define action callback type."""

View file

@ -68,6 +68,37 @@ async def test_if_fires_on_event(hass: HomeAssistant, calls, setup_comp) -> None
}
async def test_response(hass: HomeAssistant, setup_comp) -> None:
"""Test the firing of events."""
response = "I'm sorry, Dave. I'm afraid I can't do that"
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "conversation",
"command": ["Open the pod bay door Hal"],
},
"action": {
"set_conversation_response": response,
},
}
},
)
service_response = await hass.services.async_call(
"conversation",
"process",
{
"text": "Open the pod bay door Hal",
},
blocking=True,
return_response=True,
)
assert service_response["response"]["speech"]["plain"]["speech"] == response
async def test_same_trigger_multiple_sentences(
hass: HomeAssistant, calls, setup_comp
) -> None:

View file

@ -41,6 +41,7 @@ from homeassistant.helpers import (
trace,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.typing import UNDEFINED
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -4601,6 +4602,9 @@ async def test_validate_action_config(
cv.SCRIPT_ACTION_PARALLEL: {
"parallel": [templated_device_action("parallel_event")],
},
cv.SCRIPT_ACTION_SET_CONVERSATION_RESPONSE: {
"set_conversation_response": "Hello world"
},
}
expected_templates = {
cv.SCRIPT_ACTION_CHECK_CONDITION: None,
@ -5357,3 +5361,158 @@ async def test_condition_not_shorthand(
"2": [{"result": {"event": "test_event", "event_data": {}}}],
}
)
async def test_conversation_response(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test setting conversation response."""
sequence = cv.SCRIPT_SCHEMA([{"set_conversation_response": "Testing 123"}])
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
assert result.conversation_response == "Testing 123"
assert_action_trace(
{
"0": [{"result": {"conversation_response": "Testing 123"}}],
}
)
async def test_conversation_response_template(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test a templated conversation response."""
sequence = cv.SCRIPT_SCHEMA(
[
{"variables": {"my_var": "234"}},
{"set_conversation_response": '{{ "Testing " + my_var }}'},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
assert result.conversation_response == "Testing 234"
assert_action_trace(
{
"0": [{"variables": {"my_var": "234"}}],
"1": [{"result": {"conversation_response": "Testing 234"}}],
}
)
async def test_conversation_response_not_set(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test not setting conversation response."""
sequence = cv.SCRIPT_SCHEMA([])
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
assert result.conversation_response is UNDEFINED
assert_action_trace({})
async def test_conversation_response_unset(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test clearing conversation response."""
sequence = cv.SCRIPT_SCHEMA(
[
{"set_conversation_response": "Testing 123"},
{"set_conversation_response": None},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
result = await script_obj.async_run(context=Context())
assert result.conversation_response is None
assert_action_trace(
{
"0": [{"result": {"conversation_response": "Testing 123"}}],
"1": [{"result": {"conversation_response": None}}],
}
)
@pytest.mark.parametrize(
("var", "if_result", "choice", "response"),
[(1, True, "then", "If: Then"), (2, False, "else", "If: Else")],
)
async def test_conversation_response_subscript_if(
hass: HomeAssistant,
var: int,
if_result: bool,
choice: str,
response: str,
) -> None:
"""Test setting conversation response in a subscript."""
sequence = cv.SCRIPT_SCHEMA(
[
{"set_conversation_response": "Testing 123"},
{
"if": {
"condition": "template",
"value_template": "{{ var == 1 }}",
},
"then": {"set_conversation_response": "If: Then"},
"else": {"set_conversation_response": "If: Else"},
},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
run_vars = MappingProxyType({"var": var})
result = await script_obj.async_run(run_vars, context=Context())
assert result.conversation_response == response
expected_trace = {
"0": [{"result": {"conversation_response": "Testing 123"}}],
"1": [{"result": {"choice": choice}}],
"1/if": [{"result": {"result": if_result}}],
"1/if/condition/0": [{"result": {"result": var == 1, "entities": []}}],
f"1/{choice}/0": [{"result": {"conversation_response": response}}],
}
assert_action_trace(expected_trace)
@pytest.mark.parametrize(
("var", "if_result", "choice"), [(1, True, "then"), (2, False, "else")]
)
async def test_conversation_response_not_set_subscript_if(
hass: HomeAssistant,
var: int,
if_result: bool,
choice: str,
) -> None:
"""Test not setting conversation response in a subscript."""
sequence = cv.SCRIPT_SCHEMA(
[
{"set_conversation_response": "Testing 123"},
{
"if": {
"condition": "template",
"value_template": "{{ var == 1 }}",
},
"then": [],
"else": [],
},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
run_vars = MappingProxyType({"var": var})
result = await script_obj.async_run(run_vars, context=Context())
assert result.conversation_response == "Testing 123"
expected_trace = {
"0": [{"result": {"conversation_response": "Testing 123"}}],
"1": [{"result": {"choice": choice}}],
"1/if": [{"result": {"result": if_result}}],
"1/if/condition/0": [{"result": {"result": var == 1, "entities": []}}],
}
assert_action_trace(expected_trace)