Raise error instead of crashing when template passed to call service target (#47467)

This commit is contained in:
Paulus Schoutsen 2021-03-05 15:34:18 -08:00 committed by GitHub
parent 8f31b09b55
commit 4c181bbfe5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 22 deletions

View file

@ -13,10 +13,9 @@ from homeassistant.exceptions import (
TemplateError, TemplateError,
Unauthorized, Unauthorized,
) )
from homeassistant.helpers import config_validation as cv, entity from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.event import TrackTemplate, async_track_template_result from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.template import Template
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
from . import const, decorators, messages from . import const, decorators, messages
@ -132,6 +131,11 @@ async def handle_call_service(hass, connection, msg):
if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]: if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]:
blocking = False blocking = False
# We do not support templates.
target = msg.get("target")
if template.is_complex(target):
raise vol.Invalid("Templates are not supported here")
try: try:
context = connection.context(msg) context = connection.context(msg)
await hass.services.async_call( await hass.services.async_call(
@ -140,7 +144,7 @@ async def handle_call_service(hass, connection, msg):
msg.get("service_data"), msg.get("service_data"),
blocking, blocking,
context, context,
target=msg.get("target"), target=target,
) )
connection.send_message( connection.send_message(
messages.result_message(msg["id"], {"context": context}) messages.result_message(msg["id"], {"context": context})
@ -256,14 +260,14 @@ def handle_ping(hass, connection, msg):
async def handle_render_template(hass, connection, msg): async def handle_render_template(hass, connection, msg):
"""Handle render_template command.""" """Handle render_template command."""
template_str = msg["template"] template_str = msg["template"]
template = Template(template_str, hass) template_obj = template.Template(template_str, hass)
variables = msg.get("variables") variables = msg.get("variables")
timeout = msg.get("timeout") timeout = msg.get("timeout")
info = None info = None
if timeout: if timeout:
try: try:
timed_out = await template.async_render_will_timeout(timeout) timed_out = await template_obj.async_render_will_timeout(timeout)
except TemplateError as ex: except TemplateError as ex:
connection.send_error(msg["id"], const.ERR_TEMPLATE_ERROR, str(ex)) connection.send_error(msg["id"], const.ERR_TEMPLATE_ERROR, str(ex))
return return
@ -294,7 +298,7 @@ async def handle_render_template(hass, connection, msg):
try: try:
info = async_track_template_result( info = async_track_template_result(
hass, hass,
[TrackTemplate(template, variables)], [TrackTemplate(template_obj, variables)],
_template_listener, _template_listener,
raise_on_template_error=True, raise_on_template_error=True,
) )

View file

@ -21,13 +21,7 @@ from tests.common import MockEntity, MockEntityPlatform, async_mock_service
async def test_call_service(hass, websocket_client): async def test_call_service(hass, websocket_client):
"""Test call service command.""" """Test call service command."""
calls = [] calls = async_mock_service(hass, "domain_test", "test_service")
@callback
def service_call(call):
calls.append(call)
hass.services.async_register("domain_test", "test_service", service_call)
await websocket_client.send_json( await websocket_client.send_json(
{ {
@ -54,13 +48,7 @@ async def test_call_service(hass, websocket_client):
async def test_call_service_target(hass, websocket_client): async def test_call_service_target(hass, websocket_client):
"""Test call service command with target.""" """Test call service command with target."""
calls = [] calls = async_mock_service(hass, "domain_test", "test_service")
@callback
def service_call(call):
calls.append(call)
hass.services.async_register("domain_test", "test_service", service_call)
await websocket_client.send_json( await websocket_client.send_json(
{ {
@ -93,6 +81,28 @@ async def test_call_service_target(hass, websocket_client):
} }
async def test_call_service_target_template(hass, websocket_client):
"""Test call service command with target does not allow template."""
await websocket_client.send_json(
{
"id": 5,
"type": "call_service",
"domain": "domain_test",
"service": "test_service",
"service_data": {"hello": "world"},
"target": {
"entity_id": "{{ 1 }}",
},
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_INVALID_FORMAT
async def test_call_service_not_found(hass, websocket_client): async def test_call_service_not_found(hass, websocket_client):
"""Test call service command.""" """Test call service command."""
await websocket_client.send_json( await websocket_client.send_json(
@ -232,7 +242,6 @@ async def test_call_service_error(hass, websocket_client):
) )
msg = await websocket_client.receive_json() msg = await websocket_client.receive_json()
print(msg)
assert msg["id"] == 5 assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT assert msg["type"] == const.TYPE_RESULT
assert msg["success"] is False assert msg["success"] is False
@ -249,7 +258,6 @@ async def test_call_service_error(hass, websocket_client):
) )
msg = await websocket_client.receive_json() msg = await websocket_client.receive_json()
print(msg)
assert msg["id"] == 6 assert msg["id"] == 6
assert msg["type"] == const.TYPE_RESULT assert msg["type"] == const.TYPE_RESULT
assert msg["success"] is False assert msg["success"] is False