From e08ee282ab6e0a913545a81bf85098c2df2adbec Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Sep 2020 07:43:22 -0500 Subject: [PATCH] Abort execution of template renders that overwhelm the system (#40647) --- .../components/websocket_api/commands.py | 14 ++++- homeassistant/helpers/template.py | 50 +++++++++++++++++ homeassistant/util/thread.py | 33 +++++++++++ .../components/websocket_api/test_commands.py | 47 +++++++++++----- tests/helpers/test_template.py | 28 ++++++++++ tests/util/test_thread.py | 55 +++++++++++++++++++ 6 files changed, 212 insertions(+), 15 deletions(-) create mode 100644 tests/util/test_thread.py diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index d80c7934dd4..11d97f58f50 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -239,22 +239,32 @@ def handle_ping(hass, connection, msg): connection.send_message(pong_message(msg["id"])) -@callback @decorators.websocket_command( { vol.Required("type"): "render_template", vol.Required("template"): str, vol.Optional("entity_ids"): cv.entity_ids, vol.Optional("variables"): dict, + vol.Optional("timeout"): vol.Coerce(float), } ) -def handle_render_template(hass, connection, msg): +@decorators.async_response +async def handle_render_template(hass, connection, msg): """Handle render_template command.""" template_str = msg["template"] template = Template(template_str, hass) variables = msg.get("variables") + timeout = msg.get("timeout") info = None + if timeout and await template.async_render_will_timeout(timeout): + connection.send_error( + msg["id"], + const.ERR_TEMPLATE_ERROR, + f"Exceeded maximum execution time of {timeout}s", + ) + return + @callback def _template_listener(event, updates): nonlocal info diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 5564024a92b..721c1407f37 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -1,4 +1,5 @@ """Template helper methods for rendering strings with Home Assistant data.""" +import asyncio import base64 import collections.abc from datetime import datetime, timedelta @@ -36,6 +37,7 @@ from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import convert, dt as dt_util, location as loc_util from homeassistant.util.async_ import run_callback_threadsafe +from homeassistant.util.thread import ThreadWithException # mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any @@ -309,6 +311,54 @@ class Template: except jinja2.TemplateError as err: raise TemplateError(err) from err + async def async_render_will_timeout( + self, timeout: float, variables: TemplateVarsType = None, **kwargs: Any + ) -> bool: + """Check to see if rendering a template will timeout during render. + + This is intended to check for expensive templates + that will make the system unstable. The template + is rendered in the executor to ensure it does not + tie up the event loop. + + This function is not a security control and is only + intended to be used as a safety check when testing + templates. + + This method must be run in the event loop. + """ + assert self.hass + + if self.is_static: + return False + + compiled = self._compiled or self._ensure_compiled() + + if variables is not None: + kwargs.update(variables) + + finish_event = asyncio.Event() + + def _render_template(): + try: + compiled.render(kwargs) + except TimeoutError: + pass + finally: + run_callback_threadsafe(self.hass.loop, finish_event.set) + + try: + template_render_thread = ThreadWithException(target=_render_template) + template_render_thread.start() + await asyncio.wait_for(finish_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + template_render_thread.raise_exc(TimeoutError) + return True + finally: + template_render_thread.join() + + return False + @callback def async_render_to_info( self, variables: TemplateVarsType = None, **kwargs: Any diff --git a/homeassistant/util/thread.py b/homeassistant/util/thread.py index e5654e6f8c6..bf61c67172a 100644 --- a/homeassistant/util/thread.py +++ b/homeassistant/util/thread.py @@ -1,4 +1,6 @@ """Threading util helpers.""" +import ctypes +import inspect import sys import threading from typing import Any @@ -24,3 +26,34 @@ def fix_threading_exception_logging() -> None: sys.excepthook(*sys.exc_info()) threading.Thread.run = run # type: ignore + + +def _async_raise(tid: int, exctype: Any) -> None: + """Raise an exception in the threads with id tid.""" + if not inspect.isclass(exctype): + raise TypeError("Only types can be raised (not instances)") + + c_tid = ctypes.c_long(tid) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype)) + + if res == 1: + return + + # "if it returns a number greater than one, you're in trouble, + # and you should call it again with exc=NULL to revert the effect" + ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, None) + raise SystemError("PyThreadState_SetAsyncExc failed") + + +class ThreadWithException(threading.Thread): + """A thread class that supports raising exception in the thread from another thread. + + Based on + https://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread/49877671 + + """ + + def raise_exc(self, exctype: Any) -> None: + """Raise the given exception type in the context of this thread.""" + assert self.ident + _async_raise(self.ident, exctype) diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index ea6f2f42bdc..3969ff90706 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -397,9 +397,7 @@ async def test_subscribe_unsubscribe_events_state_changed( assert msg["event"]["data"]["entity_id"] == "light.permitted" -async def test_render_template_renders_template( - hass, websocket_client, hass_admin_user -): +async def test_render_template_renders_template(hass, websocket_client): """Test simple template is rendered and updated.""" hass.states.async_set("light.test", "on") @@ -437,7 +435,7 @@ async def test_render_template_renders_template( async def test_render_template_manual_entity_ids_no_longer_needed( - hass, websocket_client, hass_admin_user + hass, websocket_client ): """Test that updates to specified entity ids cause a template rerender.""" hass.states.async_set("light.test", "on") @@ -475,9 +473,7 @@ async def test_render_template_manual_entity_ids_no_longer_needed( } -async def test_render_template_with_error( - hass, websocket_client, hass_admin_user, caplog -): +async def test_render_template_with_error(hass, websocket_client, caplog): """Test a template with an error.""" await websocket_client.send_json( {"id": 5, "type": "render_template", "template": "{{ my_unknown_var() + 1 }}"} @@ -492,9 +488,7 @@ async def test_render_template_with_error( assert "TemplateError" not in caplog.text -async def test_render_template_with_delayed_error( - hass, websocket_client, hass_admin_user, caplog -): +async def test_render_template_with_delayed_error(hass, websocket_client, caplog): """Test a template with an error that only happens after a state change.""" hass.states.async_set("sensor.test", "on") await hass.async_block_till_done() @@ -539,9 +533,36 @@ async def test_render_template_with_delayed_error( assert "TemplateError" not in caplog.text -async def test_render_template_returns_with_match_all( - hass, websocket_client, hass_admin_user -): +async def test_render_template_with_timeout(hass, websocket_client, caplog): + """Test a template that will timeout.""" + + slow_template_str = """ +{% for var in range(1000) -%} + {% for var in range(1000) -%} + {{ var }} + {%- endfor %} +{%- endfor %} +""" + + await websocket_client.send_json( + { + "id": 5, + "type": "render_template", + "timeout": 0.000001, + "template": slow_template_str, + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_TEMPLATE_ERROR + + assert "TemplateError" not in caplog.text + + +async def test_render_template_returns_with_match_all(hass, websocket_client): """Test that a template that would match with all entities still return success.""" await websocket_client.send_json( {"id": 5, "type": "render_template", "template": "State is: {{ 42 }}"} diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index 7cfdd4241b7..be6b1bd2ecf 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -2455,3 +2455,31 @@ async def test_lifecycle(hass): assert info.filter("sensor.sensor1") is False assert info.filter_lifecycle("sensor.new") is True assert info.filter_lifecycle("sensor.removed") is True + + +async def test_template_timeout(hass): + """Test to see if a template will timeout.""" + for i in range(2): + hass.states.async_set(f"sensor.sensor{i}", "on") + + tmp = template.Template("{{ states | count }}", hass) + assert await tmp.async_render_will_timeout(3) is False + + tmp2 = template.Template("{{ error_invalid + 1 }}", hass) + assert await tmp2.async_render_will_timeout(3) is False + + tmp3 = template.Template("static", hass) + assert await tmp3.async_render_will_timeout(3) is False + + tmp4 = template.Template("{{ var1 }}", hass) + assert await tmp4.async_render_will_timeout(3, {"var1": "ok"}) is False + + slow_template_str = """ +{% for var in range(1000) -%} + {% for var in range(1000) -%} + {{ var }} + {%- endfor %} +{%- endfor %} +""" + tmp5 = template.Template(slow_template_str, hass) + assert await tmp5.async_render_will_timeout(0.000001) is True diff --git a/tests/util/test_thread.py b/tests/util/test_thread.py new file mode 100644 index 00000000000..d5f05f5c93e --- /dev/null +++ b/tests/util/test_thread.py @@ -0,0 +1,55 @@ +"""Test Home Assistant thread utils.""" + +import asyncio + +import pytest + +from homeassistant.util.async_ import run_callback_threadsafe +from homeassistant.util.thread import ThreadWithException + + +async def test_thread_with_exception_invalid(hass): + """Test throwing an invalid thread exception.""" + + finish_event = asyncio.Event() + + def _do_nothing(*_): + run_callback_threadsafe(hass.loop, finish_event.set) + + test_thread = ThreadWithException(target=_do_nothing) + test_thread.start() + await asyncio.wait_for(finish_event.wait(), timeout=0.1) + + with pytest.raises(TypeError): + test_thread.raise_exc(_EmptyClass()) + test_thread.join() + + +async def test_thread_not_started(hass): + """Test throwing when the thread is not started.""" + + test_thread = ThreadWithException(target=lambda *_: None) + + with pytest.raises(AssertionError): + test_thread.raise_exc(TimeoutError) + + +async def test_thread_fails_raise(hass): + """Test throwing after already ended.""" + + finish_event = asyncio.Event() + + def _do_nothing(*_): + run_callback_threadsafe(hass.loop, finish_event.set) + + test_thread = ThreadWithException(target=_do_nothing) + test_thread.start() + await asyncio.wait_for(finish_event.wait(), timeout=0.1) + test_thread.join() + + with pytest.raises(SystemError): + test_thread.raise_exc(ValueError) + + +class _EmptyClass: + """An empty class."""