Enable strict typing on script helper (#122075)

This commit is contained in:
Erik Montnemery 2024-07-17 13:51:59 +02:00 committed by GitHub
parent a0f91d27a3
commit efb7bede40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 40 deletions

View file

@ -21,6 +21,7 @@ homeassistant.helpers.entity_platform
homeassistant.helpers.entity_values homeassistant.helpers.entity_values
homeassistant.helpers.event homeassistant.helpers.event
homeassistant.helpers.reload homeassistant.helpers.reload
homeassistant.helpers.script
homeassistant.helpers.script_variables homeassistant.helpers.script_variables
homeassistant.helpers.singleton homeassistant.helpers.singleton
homeassistant.helpers.sun homeassistant.helpers.sun

View file

@ -13,7 +13,7 @@ from functools import cached_property, partial
import itertools import itertools
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Literal, TypedDict, cast from typing import Any, Literal, TypedDict, cast, overload
import async_interrupt import async_interrupt
import voluptuous as vol import voluptuous as vol
@ -75,6 +75,7 @@ from homeassistant.core import (
HassJob, HassJob,
HomeAssistant, HomeAssistant,
ServiceResponse, ServiceResponse,
State,
SupportsResponse, SupportsResponse,
callback, callback,
) )
@ -107,9 +108,7 @@ from .trace import (
trace_update_result, trace_update_result,
) )
from .trigger import async_initialize_triggers, async_validate_trigger_config from .trigger import async_initialize_triggers, async_validate_trigger_config
from .typing import UNDEFINED, ConfigType, UndefinedType from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued" SCRIPT_MODE_QUEUED = "queued"
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
future.set_result(None) future.set_result(None)
def action_trace_append(variables, path): def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables, path) trace_element = TraceElement(variables, path)
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN) trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
@ -430,7 +429,7 @@ class _ScriptRun:
if not self._stop.done(): if not self._stop.done():
self._script._changed() # noqa: SLF001 self._script._changed() # noqa: SLF001
async def _async_get_condition(self, config): async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
return await self._script._async_get_condition(config) # noqa: SLF001 return await self._script._async_get_condition(config) # noqa: SLF001
def _log( def _log(
@ -438,7 +437,7 @@ class _ScriptRun:
) -> None: ) -> None:
self._script._log(msg, *args, level=level, **kwargs) # noqa: SLF001 self._script._log(msg, *args, level=level, **kwargs) # noqa: SLF001
def _step_log(self, default_message, timeout=None): def _step_log(self, default_message: str, timeout: float | None = None) -> None:
self._script.last_action = self._action.get(CONF_ALIAS, default_message) self._script.last_action = self._action.get(CONF_ALIAS, default_message)
_timeout = ( _timeout = (
"" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})" "" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})"
@ -580,7 +579,7 @@ class _ScriptRun:
if not isinstance(exception, exceptions.HomeAssistantError): if not isinstance(exception, exceptions.HomeAssistantError):
raise exception raise exception
def _log_exception(self, exception): def _log_exception(self, exception: Exception) -> None:
action_type = cv.determine_script_action(self._action) action_type = cv.determine_script_action(self._action)
error = str(exception) error = str(exception)
@ -629,7 +628,7 @@ class _ScriptRun:
) )
raise _AbortScript from ex raise _AbortScript from ex
async def _async_delay_step(self): async def _async_delay_step(self) -> None:
"""Handle delay.""" """Handle delay."""
delay_delta = self._get_pos_time_period_template(CONF_DELAY) delay_delta = self._get_pos_time_period_template(CONF_DELAY)
@ -661,7 +660,7 @@ class _ScriptRun:
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
return None return None
async def _async_wait_template_step(self): async def _async_wait_template_step(self) -> None:
"""Handle a wait template.""" """Handle a wait template."""
timeout = self._get_timeout_seconds_from_action() timeout = self._get_timeout_seconds_from_action()
self._step_log("wait template", timeout) self._step_log("wait template", timeout)
@ -690,7 +689,9 @@ class _ScriptRun:
futures.append(done) futures.append(done)
@callback @callback
def async_script_wait(entity_id, from_s, to_s): def async_script_wait(
entity_id: str, from_s: State | None, to_s: State | None
) -> None:
"""Handle script after template condition is true.""" """Handle script after template condition is true."""
self._async_set_remaining_time_var(timeout_handle) self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["completed"] = True self._variables["wait"]["completed"] = True
@ -727,7 +728,7 @@ class _ScriptRun:
except ScriptStoppedError as ex: except ScriptStoppedError as ex:
raise asyncio.CancelledError from ex raise asyncio.CancelledError from ex
async def _async_call_service_step(self): async def _async_call_service_step(self) -> None:
"""Call the service specified in the action.""" """Call the service specified in the action."""
self._step_log("call service") self._step_log("call service")
@ -774,14 +775,14 @@ class _ScriptRun:
if response_variable: if response_variable:
self._variables[response_variable] = response_data self._variables[response_variable] = response_data
async def _async_device_step(self): async def _async_device_step(self) -> None:
"""Perform the device automation specified in the action.""" """Perform the device automation specified in the action."""
self._step_log("device automation") self._step_log("device automation")
await device_action.async_call_action_from_config( await device_action.async_call_action_from_config(
self._hass, self._action, self._variables, self._context self._hass, self._action, self._variables, self._context
) )
async def _async_scene_step(self): async def _async_scene_step(self) -> None:
"""Activate the scene specified in the action.""" """Activate the scene specified in the action."""
self._step_log("activate scene") self._step_log("activate scene")
trace_set_result(scene=self._action[CONF_SCENE]) trace_set_result(scene=self._action[CONF_SCENE])
@ -793,7 +794,7 @@ class _ScriptRun:
context=self._context, context=self._context,
) )
async def _async_event_step(self): async def _async_event_step(self) -> None:
"""Fire an event.""" """Fire an event."""
self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT])) self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT]))
event_data = {} event_data = {}
@ -815,7 +816,7 @@ class _ScriptRun:
self._action[CONF_EVENT], event_data, context=self._context self._action[CONF_EVENT], event_data, context=self._context
) )
async def _async_condition_step(self): async def _async_condition_step(self) -> None:
"""Test if condition is matching.""" """Test if condition is matching."""
self._script.last_action = self._action.get( self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION] CONF_ALIAS, self._action[CONF_CONDITION]
@ -835,12 +836,19 @@ class _ScriptRun:
if not check: if not check:
raise _ConditionFail raise _ConditionFail
def _test_conditions(self, conditions, name, condition_path=None): def _test_conditions(
self,
conditions: list[ConditionCheckerType],
name: str,
condition_path: str | None = None,
) -> bool | None:
if condition_path is None: if condition_path is None:
condition_path = name condition_path = name
@trace_condition_function @trace_condition_function
def traced_test_conditions(hass, variables): def traced_test_conditions(
hass: HomeAssistant, variables: TemplateVarsType
) -> bool | None:
try: try:
with trace_path(condition_path): with trace_path(condition_path):
for idx, cond in enumerate(conditions): for idx, cond in enumerate(conditions):
@ -856,7 +864,7 @@ class _ScriptRun:
return traced_test_conditions(self._hass, self._variables) return traced_test_conditions(self._hass, self._variables)
@async_trace_path("repeat") @async_trace_path("repeat")
async def _async_repeat_step(self): # noqa: C901 async def _async_repeat_step(self) -> None: # noqa: C901
"""Repeat a sequence.""" """Repeat a sequence."""
description = self._action.get(CONF_ALIAS, "sequence") description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT] repeat = self._action[CONF_REPEAT]
@ -876,7 +884,7 @@ class _ScriptRun:
script = self._script._get_repeat_script(self._step) # noqa: SLF001 script = self._script._get_repeat_script(self._step) # noqa: SLF001
warned_too_many_loops = False warned_too_many_loops = False
async def async_run_sequence(iteration, extra_msg=""): async def async_run_sequence(iteration: int, extra_msg: str = "") -> None:
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
with trace_path("sequence"): with trace_path("sequence"):
await self._async_run_script(script) await self._async_run_script(script)
@ -1052,7 +1060,7 @@ class _ScriptRun:
"""If sequence.""" """If sequence."""
if_data = await self._script._async_get_if_data(self._step) # noqa: SLF001 if_data = await self._script._async_get_if_data(self._step) # noqa: SLF001
test_conditions = False test_conditions: bool | None = False
try: try:
with trace_path("if"): with trace_path("if"):
test_conditions = self._test_conditions( test_conditions = self._test_conditions(
@ -1072,6 +1080,26 @@ class _ScriptRun:
with trace_path("else"): with trace_path("else"):
await self._async_run_script(if_data["if_else"]) await self._async_run_script(if_data["if_else"])
@overload
def _async_futures_with_timeout(
self,
timeout: float,
) -> tuple[
list[asyncio.Future[None]],
asyncio.TimerHandle,
asyncio.Future[None],
]: ...
@overload
def _async_futures_with_timeout(
self,
timeout: None,
) -> tuple[
list[asyncio.Future[None]],
None,
None,
]: ...
def _async_futures_with_timeout( def _async_futures_with_timeout(
self, self,
timeout: float | None, timeout: float | None,
@ -1098,7 +1126,7 @@ class _ScriptRun:
futures.append(timeout_future) futures.append(timeout_future)
return futures, timeout_handle, timeout_future return futures, timeout_handle, timeout_future
async def _async_wait_for_trigger_step(self): async def _async_wait_for_trigger_step(self) -> None:
"""Wait for a trigger event.""" """Wait for a trigger event."""
timeout = self._get_timeout_seconds_from_action() timeout = self._get_timeout_seconds_from_action()
@ -1119,12 +1147,14 @@ class _ScriptRun:
done = self._hass.loop.create_future() done = self._hass.loop.create_future()
futures.append(done) futures.append(done)
async def async_done(variables, context=None): async def async_done(
variables: dict[str, Any], context: Context | None = None
) -> None:
self._async_set_remaining_time_var(timeout_handle) self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["trigger"] = variables["trigger"] self._variables["wait"]["trigger"] = variables["trigger"]
_set_result_unless_done(done) _set_result_unless_done(done)
def log_cb(level, msg, **kwargs): def log_cb(level: int, msg: str, **kwargs: Any) -> None:
self._log(msg, level=level, **kwargs) self._log(msg, level=level, **kwargs)
remove_triggers = await async_initialize_triggers( remove_triggers = await async_initialize_triggers(
@ -1168,14 +1198,14 @@ class _ScriptRun:
unsub() unsub()
async def _async_variables_step(self): async def _async_variables_step(self) -> None:
"""Set a variable value.""" """Set a variable value."""
self._step_log("setting variables") self._step_log("setting variables")
self._variables = self._action[CONF_VARIABLES].async_render( self._variables = self._action[CONF_VARIABLES].async_render(
self._hass, self._variables, render_as_defaults=False self._hass, self._variables, render_as_defaults=False
) )
async def _async_set_conversation_response_step(self): async def _async_set_conversation_response_step(self) -> None:
"""Set conversation response.""" """Set conversation response."""
self._step_log("setting conversation response") self._step_log("setting conversation response")
resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE] resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE]
@ -1187,7 +1217,7 @@ class _ScriptRun:
) )
trace_set_result(conversation_response=self._conversation_response) trace_set_result(conversation_response=self._conversation_response)
async def _async_stop_step(self): async def _async_stop_step(self) -> None:
"""Stop script execution.""" """Stop script execution."""
stop = self._action[CONF_STOP] stop = self._action[CONF_STOP]
error = self._action.get(CONF_ERROR, False) error = self._action.get(CONF_ERROR, False)
@ -1320,7 +1350,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
) )
type _VarsType = dict[str, Any] | MappingProxyType type _VarsType = dict[str, Any] | MappingProxyType[str, Any]
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None: def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
@ -1358,7 +1388,7 @@ class ScriptRunResult:
conversation_response: str | None | UndefinedType conversation_response: str | None | UndefinedType
service_response: ServiceResponse service_response: ServiceResponse
variables: dict variables: dict[str, Any]
class Script: class Script:
@ -1413,7 +1443,7 @@ class Script:
self._set_logger(logger) self._set_logger(logger)
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self.last_action = None self.last_action: str | None = None
self.last_triggered: datetime | None = None self.last_triggered: datetime | None = None
self._runs: list[_ScriptRun] = [] self._runs: list[_ScriptRun] = []
@ -1421,7 +1451,7 @@ class Script:
self._max_exceeded = max_exceeded self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED: if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock() self._queue_lck = asyncio.Lock()
self._config_cache: dict[set[tuple], Callable[..., bool]] = {} self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
self._repeat_script: dict[int, Script] = {} self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {} self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {} self._if_data: dict[int, _IfData] = {}
@ -1714,9 +1744,11 @@ class Script:
variables["context"] = context variables["context"] = context
elif self._copy_variables_on_run: elif self._copy_variables_on_run:
variables = cast(dict, copy(run_variables)) # This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], copy(run_variables))
else: else:
variables = cast(dict, run_variables) # This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], run_variables)
# Prevent non-allowed recursive calls which will cause deadlocks when we try to # Prevent non-allowed recursive calls which will cause deadlocks when we try to
# stop (restart) or wait for (queued) our own script run. # stop (restart) or wait for (queued) our own script run.
@ -1745,9 +1777,7 @@ class Script:
cls = _ScriptRun cls = _ScriptRun
else: else:
cls = _QueuedScriptRun cls = _QueuedScriptRun
run = cls( run = cls(self._hass, self, variables, context, self._log_exceptions)
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
has_existing_runs = bool(self._runs) has_existing_runs = bool(self._runs)
self._runs.append(run) self._runs.append(run)
if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs: if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs:
@ -1772,7 +1802,9 @@ class Script:
self._changed() self._changed()
raise raise
async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None: async def _async_stop(
self, aws: list[asyncio.Task[None]], update_state: bool
) -> None:
await asyncio.wait(aws) await asyncio.wait(aws)
if update_state: if update_state:
self._changed() self._changed()
@ -1791,7 +1823,7 @@ class Script:
return return
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state))) await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config): async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
config_cache_key = frozenset((k, str(v)) for k, v in config.items()) config_cache_key = frozenset((k, str(v)) for k, v in config.items())
if not (cond := self._config_cache.get(config_cache_key)): if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config) cond = await condition.async_from_config(self._hass, config)

View file

@ -34,7 +34,7 @@ class TraceElement:
"""Container for trace data.""" """Container for trace data."""
self._child_key: str | None = None self._child_key: str | None = None
self._child_run_id: str | None = None self._child_run_id: str | None = None
self._error: Exception | None = None self._error: BaseException | None = None
self.path: str = path self.path: str = path
self._result: dict[str, Any] | None = None self._result: dict[str, Any] | None = None
self.reuse_by_child = False self.reuse_by_child = False
@ -52,7 +52,7 @@ class TraceElement:
self._child_key = child_key self._child_key = child_key
self._child_run_id = child_run_id self._child_run_id = child_run_id
def set_error(self, ex: Exception) -> None: def set_error(self, ex: BaseException | None) -> None:
"""Set error.""" """Set error."""
self._error = ex self._error = ex

View file

@ -85,6 +85,9 @@ disallow_any_generics = true
[mypy-homeassistant.helpers.reload] [mypy-homeassistant.helpers.reload]
disallow_any_generics = true disallow_any_generics = true
[mypy-homeassistant.helpers.script]
disallow_any_generics = true
[mypy-homeassistant.helpers.script_variables] [mypy-homeassistant.helpers.script_variables]
disallow_any_generics = true disallow_any_generics = true