Enable strict typing on script integration (#122079)

This commit is contained in:
Erik Montnemery 2024-07-17 15:11:06 +02:00 committed by GitHub
parent 1e8da192b6
commit 07ceafed62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 56 additions and 24 deletions

View file

@ -385,6 +385,7 @@ homeassistant.components.samsungtv.*
homeassistant.components.scene.* homeassistant.components.scene.*
homeassistant.components.schedule.* homeassistant.components.schedule.*
homeassistant.components.scrape.* homeassistant.components.scrape.*
homeassistant.components.script.*
homeassistant.components.search.* homeassistant.components.search.*
homeassistant.components.select.* homeassistant.components.select.*
homeassistant.components.sensibo.* homeassistant.components.sensibo.*

View file

@ -50,6 +50,7 @@ from homeassistant.helpers.script import (
CONF_MAX, CONF_MAX,
CONF_MAX_EXCEEDED, CONF_MAX_EXCEEDED,
Script, Script,
ScriptRunResult,
script_stack_cv, script_stack_cv,
) )
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
@ -82,7 +83,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
@bind_hass @bind_hass
def is_on(hass, entity_id): def is_on(hass: HomeAssistant, entity_id: str) -> bool:
"""Return if the script is on based on the statemachine.""" """Return if the script is on based on the statemachine."""
return hass.states.is_state(entity_id, STATE_ON) return hass.states.is_state(entity_id, STATE_ON)
@ -498,8 +499,16 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
icon = None icon = None
_attr_should_poll = False _attr_should_poll = False
_attr_unique_id: str
def __init__(self, hass, key, cfg, raw_config, blueprint_inputs): def __init__(
self,
hass: HomeAssistant,
key: str,
cfg: ConfigType,
raw_config: ConfigType | None,
blueprint_inputs: ConfigType | None,
) -> None:
"""Initialize the script.""" """Initialize the script."""
self.icon = cfg.get(CONF_ICON) self.icon = cfg.get(CONF_ICON)
self.description = cfg[CONF_DESCRIPTION] self.description = cfg[CONF_DESCRIPTION]
@ -529,7 +538,7 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
self._attr_name = self.script.name self._attr_name = self.script.name
@property @property
def extra_state_attributes(self): def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes.""" """Return the state attributes."""
script = self.script script = self.script
attrs = { attrs = {
@ -544,7 +553,7 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
return attrs return attrs
@property @property
def is_on(self): def is_on(self) -> bool:
"""Return true if script is on.""" """Return true if script is on."""
return self.script.is_running return self.script.is_running
@ -564,11 +573,12 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
return self.script.referenced_areas return self.script.referenced_areas
@property @property
def referenced_blueprint(self): def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None.""" """Return referenced blueprint or None."""
if self._blueprint_inputs is None: if self._blueprint_inputs is None:
return None return None
return self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH] path: str = self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]
return path
@cached_property @cached_property
def referenced_devices(self) -> set[str]: def referenced_devices(self) -> set[str]:
@ -581,24 +591,24 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
return self.script.referenced_entities return self.script.referenced_entities
@callback @callback
def async_change_listener(self): def async_change_listener(self) -> None:
"""Update state.""" """Update state."""
self.async_write_ha_state() self.async_write_ha_state()
self._changed.set() self._changed.set()
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs: Any) -> None:
"""Run the script. """Run the script.
Depending on the script's run mode, this may do nothing, restart the script or Depending on the script's run mode, this may do nothing, restart the script or
fire an additional parallel run. fire an additional parallel run.
""" """
variables = kwargs.get("variables") variables: dict[str, Any] | None = kwargs.get("variables")
context = kwargs.get("context") context: Context = kwargs["context"]
wait = kwargs.get("wait", True) wait: bool = kwargs.get("wait", True)
await self._async_start_run(variables, context, wait) await self._async_start_run(variables, context, wait)
async def _async_start_run( async def _async_start_run(
self, variables: dict, context: Context, wait: bool self, variables: dict[str, Any] | None, context: Context, wait: bool
) -> ServiceResponse: ) -> ServiceResponse:
"""Start the run of a script.""" """Start the run of a script."""
self.async_set_context(context) self.async_set_context(context)
@ -633,10 +643,12 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
await self._changed.wait() await self._changed.wait()
return None return None
async def _async_run(self, variables, context): async def _async_run(
self, variables: dict[str, Any] | None, context: Context
) -> ScriptRunResult | None:
with trace_script( with trace_script(
self.hass, self.hass,
self.unique_id, self._attr_unique_id,
self.raw_config, self.raw_config,
self._blueprint_inputs, self._blueprint_inputs,
context, context,
@ -651,7 +663,7 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
script_vars = {"this": this, **(variables or {})} script_vars = {"this": this, **(variables or {})}
return await self.script.async_run(script_vars, context) return await self.script.async_run(script_vars, context)
async def async_turn_off(self, **kwargs): async def async_turn_off(self, **kwargs: Any) -> None:
"""Stop running the script. """Stop running the script.
If multiple runs are in progress, all will be stopped. If multiple runs are in progress, all will be stopped.
@ -696,12 +708,12 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
): ):
self.script.last_triggered = parse_datetime(last_triggered) self.script.last_triggered = parse_datetime(last_triggered)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Stop script and remove service when it will be removed from HA.""" """Stop script and remove service when it will be removed from HA."""
await self.script.async_stop() await self.script.async_stop()
# remove service # remove service
self.hass.services.async_remove(DOMAIN, self.unique_id) self.hass.services.async_remove(DOMAIN, self._attr_unique_id)
@websocket_api.websocket_command({"type": "script/config", "entity_id": str}) @websocket_api.websocket_command({"type": "script/config", "entity_id": str})

View file

@ -256,7 +256,7 @@ async def async_validate_config_item(
return await _async_validate_config_item(hass, object_id, config, True, False) return await _async_validate_config_item(hass, object_id, config, True, False)
async def async_validate_config(hass, config): async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType:
"""Validate config.""" """Validate config."""
scripts = {} scripts = {}
for _, p_config in config_per_platform(config, DOMAIN): for _, p_config in config_per_platform(config, DOMAIN):

View file

@ -1,24 +1,33 @@
"""Describe logbook events.""" """Describe logbook events."""
from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import ( from homeassistant.components.logbook import (
LOGBOOK_ENTRY_CONTEXT_ID, LOGBOOK_ENTRY_CONTEXT_ID,
LOGBOOK_ENTRY_ENTITY_ID, LOGBOOK_ENTRY_ENTITY_ID,
LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_MESSAGE,
LOGBOOK_ENTRY_NAME, LOGBOOK_ENTRY_NAME,
LazyEventPartialState,
) )
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from . import DOMAIN, EVENT_SCRIPT_STARTED from . import DOMAIN, EVENT_SCRIPT_STARTED
@callback @callback
def async_describe_events(hass, async_describe_event): def async_describe_events(
hass: HomeAssistant,
async_describe_event: Callable[
[str, str, Callable[[LazyEventPartialState], dict[str, Any]]], None
],
) -> None:
"""Describe logbook events.""" """Describe logbook events."""
@callback @callback
def async_describe_logbook_event(event): def async_describe_logbook_event(event: LazyEventPartialState) -> dict[str, Any]:
"""Describe the logbook event.""" """Describe a logbook event."""
data = event.data data = event.data
return { return {
LOGBOOK_ENTRY_NAME: data.get(ATTR_NAME), LOGBOOK_ENTRY_NAME: data.get(ATTR_NAME),

View file

@ -26,8 +26,8 @@ class ScriptTrace(ActionTrace):
def trace_script( def trace_script(
hass: HomeAssistant, hass: HomeAssistant,
item_id: str, item_id: str,
config: dict[str, Any], config: dict[str, Any] | None,
blueprint_inputs: dict[str, Any], blueprint_inputs: dict[str, Any] | None,
context: Context, context: Context,
trace_config: dict[str, Any], trace_config: dict[str, Any],
) -> Iterator[ScriptTrace]: ) -> Iterator[ScriptTrace]:

View file

@ -3606,6 +3606,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.script.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.search.*] [mypy-homeassistant.components.search.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true