Enable augmented-assignment operations in scripts (#108081)
This commit is contained in:
parent
33ff6b5b6e
commit
5d23a1f84f
2 changed files with 79 additions and 0 deletions
|
@ -2,8 +2,11 @@
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
|
from numbers import Number
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from RestrictedPython import (
|
from RestrictedPython import (
|
||||||
compile_restricted_exec,
|
compile_restricted_exec,
|
||||||
|
@ -146,6 +149,36 @@ def discover_scripts(hass):
|
||||||
async_set_service_schema(hass, DOMAIN, name, service_desc)
|
async_set_service_schema(hass, DOMAIN, name, service_desc)
|
||||||
|
|
||||||
|
|
||||||
|
IOPERATOR_TO_OPERATOR = {
|
||||||
|
"%=": operator.mod,
|
||||||
|
"&=": operator.and_,
|
||||||
|
"**=": operator.pow,
|
||||||
|
"*=": operator.mul,
|
||||||
|
"+=": operator.add,
|
||||||
|
"-=": operator.sub,
|
||||||
|
"//=": operator.floordiv,
|
||||||
|
"/=": operator.truediv,
|
||||||
|
"<<=": operator.lshift,
|
||||||
|
">>=": operator.rshift,
|
||||||
|
"@=": operator.matmul,
|
||||||
|
"^=": operator.xor,
|
||||||
|
"|=": operator.or_,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any:
|
||||||
|
"""Implement augmented-assign (+=, -=, etc.) operators for restricted code.
|
||||||
|
|
||||||
|
See RestrictedPython's `visit_AugAssign` for details.
|
||||||
|
"""
|
||||||
|
if not isinstance(target, (list, Number, str)):
|
||||||
|
raise ScriptError(f"The {op!r} operation is not allowed on a {type(target)}")
|
||||||
|
op_fun = IOPERATOR_TO_OPERATOR.get(op)
|
||||||
|
if not op_fun:
|
||||||
|
raise ScriptError(f"The {op!r} operation is not allowed")
|
||||||
|
return op_fun(target, operand)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def execute_script(hass, name, data=None, return_response=False):
|
def execute_script(hass, name, data=None, return_response=False):
|
||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
|
@ -223,6 +256,7 @@ def execute(hass, filename, source, data=None, return_response=False):
|
||||||
"_getitem_": default_guarded_getitem,
|
"_getitem_": default_guarded_getitem,
|
||||||
"_iter_unpack_sequence_": guarded_iter_unpack_sequence,
|
"_iter_unpack_sequence_": guarded_iter_unpack_sequence,
|
||||||
"_unpack_sequence_": guarded_unpack_sequence,
|
"_unpack_sequence_": guarded_unpack_sequence,
|
||||||
|
"_inplacevar_": guarded_inplacevar,
|
||||||
"hass": hass,
|
"hass": hass,
|
||||||
"data": data or {},
|
"data": data or {},
|
||||||
"logger": logger,
|
"logger": logger,
|
||||||
|
|
|
@ -596,3 +596,48 @@ output = f"hello {data.get('name', 'World')}"
|
||||||
blocking=True,
|
blocking=True,
|
||||||
return_response=True,
|
return_response=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_augmented_assignment_operations(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that augmented assignment operations work."""
|
||||||
|
source = """
|
||||||
|
a = 10
|
||||||
|
a += 20
|
||||||
|
a *= 5
|
||||||
|
a -= 8
|
||||||
|
b = "foo"
|
||||||
|
b += "bar"
|
||||||
|
b *= 2
|
||||||
|
c = []
|
||||||
|
c += [1, 2, 3]
|
||||||
|
c *= 2
|
||||||
|
hass.states.set('hello.a', a)
|
||||||
|
hass.states.set('hello.b', b)
|
||||||
|
hass.states.set('hello.c', c)
|
||||||
|
"""
|
||||||
|
|
||||||
|
hass.async_add_executor_job(execute, hass, "aug_assign.py", source, {})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.states.get("hello.a").state == str(((10 + 20) * 5) - 8)
|
||||||
|
assert hass.states.get("hello.b").state == ("foo" + "bar") * 2
|
||||||
|
assert hass.states.get("hello.c").state == str([1, 2, 3] * 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("case", "error"),
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
"d = datetime.date(2024, 1, 1); d += 5",
|
||||||
|
"The '+=' operation is not allowed",
|
||||||
|
id="datetime.date",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_prohibited_augmented_assignment_operations(
|
||||||
|
hass: HomeAssistant, case: str, error: str, caplog
|
||||||
|
) -> None:
|
||||||
|
"""Test that prohibited augmented assignment operations raise an error."""
|
||||||
|
hass.async_add_executor_job(execute, hass, "aug_assign_prohibited.py", case, {})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert error in caplog.text
|
||||||
|
|
Loading…
Add table
Reference in a new issue