Raise error instead of crashing when template passed to call service target (#47467)
This commit is contained in:
parent
8f31b09b55
commit
4c181bbfe5
2 changed files with 34 additions and 22 deletions
|
@ -13,10 +13,9 @@ from homeassistant.exceptions import (
|
||||||
TemplateError,
|
TemplateError,
|
||||||
Unauthorized,
|
Unauthorized,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import config_validation as cv, entity
|
from homeassistant.helpers import config_validation as cv, entity, template
|
||||||
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
|
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.helpers.template import Template
|
|
||||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||||
|
|
||||||
from . import const, decorators, messages
|
from . import const, decorators, messages
|
||||||
|
@ -132,6 +131,11 @@ async def handle_call_service(hass, connection, msg):
|
||||||
if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]:
|
if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]:
|
||||||
blocking = False
|
blocking = False
|
||||||
|
|
||||||
|
# We do not support templates.
|
||||||
|
target = msg.get("target")
|
||||||
|
if template.is_complex(target):
|
||||||
|
raise vol.Invalid("Templates are not supported here")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = connection.context(msg)
|
context = connection.context(msg)
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
|
@ -140,7 +144,7 @@ async def handle_call_service(hass, connection, msg):
|
||||||
msg.get("service_data"),
|
msg.get("service_data"),
|
||||||
blocking,
|
blocking,
|
||||||
context,
|
context,
|
||||||
target=msg.get("target"),
|
target=target,
|
||||||
)
|
)
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
messages.result_message(msg["id"], {"context": context})
|
messages.result_message(msg["id"], {"context": context})
|
||||||
|
@ -256,14 +260,14 @@ def handle_ping(hass, connection, msg):
|
||||||
async def handle_render_template(hass, connection, msg):
|
async def handle_render_template(hass, connection, msg):
|
||||||
"""Handle render_template command."""
|
"""Handle render_template command."""
|
||||||
template_str = msg["template"]
|
template_str = msg["template"]
|
||||||
template = Template(template_str, hass)
|
template_obj = template.Template(template_str, hass)
|
||||||
variables = msg.get("variables")
|
variables = msg.get("variables")
|
||||||
timeout = msg.get("timeout")
|
timeout = msg.get("timeout")
|
||||||
info = None
|
info = None
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
try:
|
try:
|
||||||
timed_out = await template.async_render_will_timeout(timeout)
|
timed_out = await template_obj.async_render_will_timeout(timeout)
|
||||||
except TemplateError as ex:
|
except TemplateError as ex:
|
||||||
connection.send_error(msg["id"], const.ERR_TEMPLATE_ERROR, str(ex))
|
connection.send_error(msg["id"], const.ERR_TEMPLATE_ERROR, str(ex))
|
||||||
return
|
return
|
||||||
|
@ -294,7 +298,7 @@ async def handle_render_template(hass, connection, msg):
|
||||||
try:
|
try:
|
||||||
info = async_track_template_result(
|
info = async_track_template_result(
|
||||||
hass,
|
hass,
|
||||||
[TrackTemplate(template, variables)],
|
[TrackTemplate(template_obj, variables)],
|
||||||
_template_listener,
|
_template_listener,
|
||||||
raise_on_template_error=True,
|
raise_on_template_error=True,
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,13 +21,7 @@ from tests.common import MockEntity, MockEntityPlatform, async_mock_service
|
||||||
|
|
||||||
async def test_call_service(hass, websocket_client):
|
async def test_call_service(hass, websocket_client):
|
||||||
"""Test call service command."""
|
"""Test call service command."""
|
||||||
calls = []
|
calls = async_mock_service(hass, "domain_test", "test_service")
|
||||||
|
|
||||||
@callback
|
|
||||||
def service_call(call):
|
|
||||||
calls.append(call)
|
|
||||||
|
|
||||||
hass.services.async_register("domain_test", "test_service", service_call)
|
|
||||||
|
|
||||||
await websocket_client.send_json(
|
await websocket_client.send_json(
|
||||||
{
|
{
|
||||||
|
@ -54,13 +48,7 @@ async def test_call_service(hass, websocket_client):
|
||||||
|
|
||||||
async def test_call_service_target(hass, websocket_client):
|
async def test_call_service_target(hass, websocket_client):
|
||||||
"""Test call service command with target."""
|
"""Test call service command with target."""
|
||||||
calls = []
|
calls = async_mock_service(hass, "domain_test", "test_service")
|
||||||
|
|
||||||
@callback
|
|
||||||
def service_call(call):
|
|
||||||
calls.append(call)
|
|
||||||
|
|
||||||
hass.services.async_register("domain_test", "test_service", service_call)
|
|
||||||
|
|
||||||
await websocket_client.send_json(
|
await websocket_client.send_json(
|
||||||
{
|
{
|
||||||
|
@ -93,6 +81,28 @@ async def test_call_service_target(hass, websocket_client):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_service_target_template(hass, websocket_client):
|
||||||
|
"""Test call service command with target does not allow template."""
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "call_service",
|
||||||
|
"domain": "domain_test",
|
||||||
|
"service": "test_service",
|
||||||
|
"service_data": {"hello": "world"},
|
||||||
|
"target": {
|
||||||
|
"entity_id": "{{ 1 }}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 5
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == const.ERR_INVALID_FORMAT
|
||||||
|
|
||||||
|
|
||||||
async def test_call_service_not_found(hass, websocket_client):
|
async def test_call_service_not_found(hass, websocket_client):
|
||||||
"""Test call service command."""
|
"""Test call service command."""
|
||||||
await websocket_client.send_json(
|
await websocket_client.send_json(
|
||||||
|
@ -232,7 +242,6 @@ async def test_call_service_error(hass, websocket_client):
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = await websocket_client.receive_json()
|
msg = await websocket_client.receive_json()
|
||||||
print(msg)
|
|
||||||
assert msg["id"] == 5
|
assert msg["id"] == 5
|
||||||
assert msg["type"] == const.TYPE_RESULT
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
assert msg["success"] is False
|
assert msg["success"] is False
|
||||||
|
@ -249,7 +258,6 @@ async def test_call_service_error(hass, websocket_client):
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = await websocket_client.receive_json()
|
msg = await websocket_client.receive_json()
|
||||||
print(msg)
|
|
||||||
assert msg["id"] == 6
|
assert msg["id"] == 6
|
||||||
assert msg["type"] == const.TYPE_RESULT
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
assert msg["success"] is False
|
assert msg["success"] is False
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue