Enable augmented-assignment operations in scripts ()

This commit is contained in:
Aarni Koskela 2024-02-18 03:32:23 +02:00 committed by GitHub
parent 33ff6b5b6e
commit 5d23a1f84f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 79 additions and 0 deletions
homeassistant/components/python_script
tests/components/python_script

View file

@ -2,8 +2,11 @@
import datetime import datetime
import glob import glob
import logging import logging
from numbers import Number
import operator
import os import os
import time import time
from typing import Any
from RestrictedPython import ( from RestrictedPython import (
compile_restricted_exec, compile_restricted_exec,
@ -146,6 +149,36 @@ def discover_scripts(hass):
async_set_service_schema(hass, DOMAIN, name, service_desc) async_set_service_schema(hass, DOMAIN, name, service_desc)
IOPERATOR_TO_OPERATOR = {
"%=": operator.mod,
"&=": operator.and_,
"**=": operator.pow,
"*=": operator.mul,
"+=": operator.add,
"-=": operator.sub,
"//=": operator.floordiv,
"/=": operator.truediv,
"<<=": operator.lshift,
">>=": operator.rshift,
"@=": operator.matmul,
"^=": operator.xor,
"|=": operator.or_,
}
def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any:
"""Implement augmented-assign (+=, -=, etc.) operators for restricted code.
See RestrictedPython's `visit_AugAssign` for details.
"""
if not isinstance(target, (list, Number, str)):
raise ScriptError(f"The {op!r} operation is not allowed on a {type(target)}")
op_fun = IOPERATOR_TO_OPERATOR.get(op)
if not op_fun:
raise ScriptError(f"The {op!r} operation is not allowed")
return op_fun(target, operand)
@bind_hass @bind_hass
def execute_script(hass, name, data=None, return_response=False): def execute_script(hass, name, data=None, return_response=False):
"""Execute a script.""" """Execute a script."""
@ -223,6 +256,7 @@ def execute(hass, filename, source, data=None, return_response=False):
"_getitem_": default_guarded_getitem, "_getitem_": default_guarded_getitem,
"_iter_unpack_sequence_": guarded_iter_unpack_sequence, "_iter_unpack_sequence_": guarded_iter_unpack_sequence,
"_unpack_sequence_": guarded_unpack_sequence, "_unpack_sequence_": guarded_unpack_sequence,
"_inplacevar_": guarded_inplacevar,
"hass": hass, "hass": hass,
"data": data or {}, "data": data or {},
"logger": logger, "logger": logger,

View file

@ -596,3 +596,48 @@ output = f"hello {data.get('name', 'World')}"
blocking=True, blocking=True,
return_response=True, return_response=True,
) )
async def test_augmented_assignment_operations(hass: HomeAssistant) -> None:
"""Test that augmented assignment operations work."""
source = """
a = 10
a += 20
a *= 5
a -= 8
b = "foo"
b += "bar"
b *= 2
c = []
c += [1, 2, 3]
c *= 2
hass.states.set('hello.a', a)
hass.states.set('hello.b', b)
hass.states.set('hello.c', c)
"""
hass.async_add_executor_job(execute, hass, "aug_assign.py", source, {})
await hass.async_block_till_done()
assert hass.states.get("hello.a").state == str(((10 + 20) * 5) - 8)
assert hass.states.get("hello.b").state == ("foo" + "bar") * 2
assert hass.states.get("hello.c").state == str([1, 2, 3] * 2)
@pytest.mark.parametrize(
("case", "error"),
[
pytest.param(
"d = datetime.date(2024, 1, 1); d += 5",
"The '+=' operation is not allowed",
id="datetime.date",
),
],
)
async def test_prohibited_augmented_assignment_operations(
hass: HomeAssistant, case: str, error: str, caplog
) -> None:
"""Test that prohibited augmented assignment operations raise an error."""
hass.async_add_executor_job(execute, hass, "aug_assign_prohibited.py", case, {})
await hass.async_block_till_done()
assert error in caplog.text