Helpers type hint improvements (#44964)
This commit is contained in:
parent
3569d92385
commit
3a88a4120e
4 changed files with 47 additions and 45 deletions
|
@ -556,7 +556,7 @@ def template(value: Optional[Any]) -> template_helper.Template:
|
||||||
template_value = template_helper.Template(str(value)) # type: ignore
|
template_value = template_helper.Template(str(value)) # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
template_value.ensure_valid() # type: ignore[no-untyped-call]
|
template_value.ensure_valid()
|
||||||
return template_value
|
return template_value
|
||||||
except TemplateError as ex:
|
except TemplateError as ex:
|
||||||
raise vol.Invalid(f"invalid template ({ex})") from ex
|
raise vol.Invalid(f"invalid template ({ex})") from ex
|
||||||
|
@ -574,7 +574,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
|
||||||
|
|
||||||
template_value = template_helper.Template(str(value)) # type: ignore
|
template_value = template_helper.Template(str(value)) # type: ignore
|
||||||
try:
|
try:
|
||||||
template_value.ensure_valid() # type: ignore[no-untyped-call]
|
template_value.ensure_valid()
|
||||||
return template_value
|
return template_value
|
||||||
except TemplateError as ex:
|
except TemplateError as ex:
|
||||||
raise vol.Invalid(f"invalid template ({ex})") from ex
|
raise vol.Invalid(f"invalid template ({ex})") from ex
|
||||||
|
|
|
@ -26,7 +26,6 @@ from .event import async_call_later, async_track_time_interval
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .entity import Entity
|
from .entity import Entity
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
SLOW_SETUP_WARNING = 10
|
SLOW_SETUP_WARNING = 10
|
||||||
SLOW_SETUP_MAX_WAIT = 60
|
SLOW_SETUP_MAX_WAIT = 60
|
||||||
|
@ -81,7 +80,7 @@ class EntityPlatform:
|
||||||
self.platform_name, []
|
self.platform_name, []
|
||||||
).append(self)
|
).append(self)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Represent an EntityPlatform."""
|
"""Represent an EntityPlatform."""
|
||||||
return f"<EntityPlatform domain={self.domain} platform_name={self.platform_name} config_entry={self.config_entry}>"
|
return f"<EntityPlatform domain={self.domain} platform_name={self.platform_name} config_entry={self.config_entry}>"
|
||||||
|
|
||||||
|
@ -116,7 +115,7 @@ class EntityPlatform:
|
||||||
|
|
||||||
return self.parallel_updates
|
return self.parallel_updates
|
||||||
|
|
||||||
async def async_setup(self, platform_config, discovery_info=None):
|
async def async_setup(self, platform_config, discovery_info=None): # type: ignore[no-untyped-def]
|
||||||
"""Set up the platform from a config file."""
|
"""Set up the platform from a config file."""
|
||||||
platform = self.platform
|
platform = self.platform
|
||||||
hass = self.hass
|
hass = self.hass
|
||||||
|
@ -162,7 +161,7 @@ class EntityPlatform:
|
||||||
platform = self.platform
|
platform = self.platform
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_setup_task():
|
def async_create_setup_task(): # type: ignore[no-untyped-def]
|
||||||
"""Get task to set up platform."""
|
"""Get task to set up platform."""
|
||||||
return platform.async_setup_entry( # type: ignore
|
return platform.async_setup_entry( # type: ignore
|
||||||
self.hass, config_entry, self._async_schedule_add_entities
|
self.hass, config_entry, self._async_schedule_add_entities
|
||||||
|
@ -218,7 +217,7 @@ class EntityPlatform:
|
||||||
wait_time,
|
wait_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def setup_again(now):
|
async def setup_again(now): # type: ignore[no-untyped-def]
|
||||||
"""Run setup again."""
|
"""Run setup again."""
|
||||||
self._async_cancel_retry_setup = None
|
self._async_cancel_retry_setup = None
|
||||||
await self._async_setup_platform(async_create_setup_task, tries)
|
await self._async_setup_platform(async_create_setup_task, tries)
|
||||||
|
@ -340,7 +339,7 @@ class EntityPlatform:
|
||||||
self.scan_interval,
|
self.scan_interval,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_add_entity(
|
async def _async_add_entity( # type: ignore[no-untyped-def]
|
||||||
self, entity, update_before_add, entity_registry, device_registry
|
self, entity, update_before_add, entity_registry, device_registry
|
||||||
):
|
):
|
||||||
"""Add an entity to the platform."""
|
"""Add an entity to the platform."""
|
||||||
|
@ -560,7 +559,7 @@ class EntityPlatform:
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_entity_service(self, name, schema, func, required_features=None):
|
def async_register_entity_service(self, name, schema, func, required_features=None): # type: ignore[no-untyped-def]
|
||||||
"""Register an entity service.
|
"""Register an entity service.
|
||||||
|
|
||||||
Services will automatically be shared by all platforms of the same domain.
|
Services will automatically be shared by all platforms of the same domain.
|
||||||
|
|
|
@ -219,7 +219,7 @@ class _ScriptRun:
|
||||||
self._stop = asyncio.Event()
|
self._stop = asyncio.Event()
|
||||||
self._stopped = asyncio.Event()
|
self._stopped = asyncio.Event()
|
||||||
|
|
||||||
def _changed(self):
|
def _changed(self) -> None:
|
||||||
if not self._stop.is_set():
|
if not self._stop.is_set():
|
||||||
self._script._changed() # pylint: disable=protected-access
|
self._script._changed() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@ -227,7 +227,7 @@ class _ScriptRun:
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return await self._script._async_get_condition(config)
|
return await self._script._async_get_condition(config)
|
||||||
|
|
||||||
def _log(self, msg, *args, level=logging.INFO):
|
def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
|
||||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||||
|
|
||||||
async def async_run(self) -> None:
|
async def async_run(self) -> None:
|
||||||
|
@ -257,7 +257,7 @@ class _ScriptRun:
|
||||||
self._log_exception(ex)
|
self._log_exception(ex)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _finish(self):
|
def _finish(self) -> None:
|
||||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||||
if not self._script.is_running:
|
if not self._script.is_running:
|
||||||
self._script.last_action = None
|
self._script.last_action = None
|
||||||
|
@ -389,7 +389,7 @@ class _ScriptRun:
|
||||||
async def _async_run_long_action(self, long_task):
|
async def _async_run_long_action(self, long_task):
|
||||||
"""Run a long task while monitoring for stop request."""
|
"""Run a long task while monitoring for stop request."""
|
||||||
|
|
||||||
async def async_cancel_long_task():
|
async def async_cancel_long_task() -> None:
|
||||||
# Stop long task and wait for it to finish.
|
# Stop long task and wait for it to finish.
|
||||||
long_task.cancel()
|
long_task.cancel()
|
||||||
try:
|
try:
|
||||||
|
@ -586,7 +586,7 @@ class _ScriptRun:
|
||||||
else:
|
else:
|
||||||
del self._variables["repeat"]
|
del self._variables["repeat"]
|
||||||
|
|
||||||
async def _async_choose_step(self):
|
async def _async_choose_step(self) -> None:
|
||||||
"""Choose a sequence."""
|
"""Choose a sequence."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
choose_data = await self._script._async_get_choose_data(self._step)
|
choose_data = await self._script._async_get_choose_data(self._step)
|
||||||
|
@ -706,7 +706,7 @@ class _QueuedScriptRun(_ScriptRun):
|
||||||
else:
|
else:
|
||||||
await super().async_run()
|
await super().async_run()
|
||||||
|
|
||||||
def _finish(self):
|
def _finish(self) -> None:
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if self.lock_acquired:
|
if self.lock_acquired:
|
||||||
self._script._queue_lck.release()
|
self._script._queue_lck.release()
|
||||||
|
@ -868,7 +868,7 @@ class Script:
|
||||||
if choose_data["default"]:
|
if choose_data["default"]:
|
||||||
choose_data["default"].update_logger(self._logger)
|
choose_data["default"].update_logger(self._logger)
|
||||||
|
|
||||||
def _changed(self):
|
def _changed(self) -> None:
|
||||||
if self._change_listener_job:
|
if self._change_listener_job:
|
||||||
self._hass.async_run_hass_job(self._change_listener_job)
|
self._hass.async_run_hass_job(self._change_listener_job)
|
||||||
|
|
||||||
|
@ -898,7 +898,7 @@ class Script:
|
||||||
if self._referenced_devices is not None:
|
if self._referenced_devices is not None:
|
||||||
return self._referenced_devices
|
return self._referenced_devices
|
||||||
|
|
||||||
referenced = set()
|
referenced: Set[str] = set()
|
||||||
|
|
||||||
for step in self.sequence:
|
for step in self.sequence:
|
||||||
action = cv.determine_script_action(step)
|
action = cv.determine_script_action(step)
|
||||||
|
@ -927,7 +927,7 @@ class Script:
|
||||||
if self._referenced_entities is not None:
|
if self._referenced_entities is not None:
|
||||||
return self._referenced_entities
|
return self._referenced_entities
|
||||||
|
|
||||||
referenced = set()
|
referenced: Set[str] = set()
|
||||||
|
|
||||||
for step in self.sequence:
|
for step in self.sequence:
|
||||||
action = cv.determine_script_action(step)
|
action = cv.determine_script_action(step)
|
||||||
|
@ -1128,9 +1128,9 @@ class Script:
|
||||||
self._choose_data[step] = choose_data
|
self._choose_data[step] = choose_data
|
||||||
return choose_data
|
return choose_data
|
||||||
|
|
||||||
def _log(self, msg, *args, level=logging.INFO):
|
def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
|
||||||
msg = f"%s: {msg}"
|
msg = f"%s: {msg}"
|
||||||
args = [self.name, *args]
|
args = (self.name, *args)
|
||||||
|
|
||||||
if level == _LOG_EXCEPTION:
|
if level == _LOG_EXCEPTION:
|
||||||
self._logger.exception(msg, *args)
|
self._logger.exception(msg, *args)
|
||||||
|
|
|
@ -11,7 +11,7 @@ import math
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union
|
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union, cast
|
||||||
from urllib.parse import urlencode as urllib_urlencode
|
from urllib.parse import urlencode as urllib_urlencode
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
@ -38,8 +38,7 @@ from homeassistant.util import convert, dt as dt_util, location as loc_util
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
from homeassistant.util.thread import ThreadWithException
|
from homeassistant.util.thread import ThreadWithException
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
|
@ -140,7 +139,7 @@ def gen_result_wrapper(kls):
|
||||||
if kls is set:
|
if kls is set:
|
||||||
return str(set(self))
|
return str(set(self))
|
||||||
|
|
||||||
return kls.__str__(self)
|
return cast(str, kls.__str__(self))
|
||||||
|
|
||||||
return self.render_result
|
return self.render_result
|
||||||
|
|
||||||
|
@ -173,7 +172,8 @@ class TupleWrapper(tuple, ResultWrapper):
|
||||||
|
|
||||||
|
|
||||||
RESULT_WRAPPERS: Dict[Type, Type] = {
|
RESULT_WRAPPERS: Dict[Type, Type] = {
|
||||||
kls: gen_result_wrapper(kls) for kls in (list, dict, set)
|
kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call]
|
||||||
|
for kls in (list, dict, set)
|
||||||
}
|
}
|
||||||
RESULT_WRAPPERS[tuple] = TupleWrapper
|
RESULT_WRAPPERS[tuple] = TupleWrapper
|
||||||
|
|
||||||
|
@ -195,15 +195,15 @@ class RenderInfo:
|
||||||
# Will be set sensibly once frozen.
|
# Will be set sensibly once frozen.
|
||||||
self.filter_lifecycle = _true
|
self.filter_lifecycle = _true
|
||||||
self.filter = _true
|
self.filter = _true
|
||||||
self._result = None
|
self._result: Optional[str] = None
|
||||||
self.is_static = False
|
self.is_static = False
|
||||||
self.exception = None
|
self.exception: Optional[TemplateError] = None
|
||||||
self.all_states = False
|
self.all_states = False
|
||||||
self.all_states_lifecycle = False
|
self.all_states_lifecycle = False
|
||||||
self.domains = set()
|
self.domains = set()
|
||||||
self.domains_lifecycle = set()
|
self.domains_lifecycle = set()
|
||||||
self.entities = set()
|
self.entities = set()
|
||||||
self.rate_limit = None
|
self.rate_limit: Optional[timedelta] = None
|
||||||
self.has_time = False
|
self.has_time = False
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
@ -228,7 +228,7 @@ class RenderInfo:
|
||||||
"""Results of the template computation."""
|
"""Results of the template computation."""
|
||||||
if self.exception is not None:
|
if self.exception is not None:
|
||||||
raise self.exception
|
raise self.exception
|
||||||
return self._result
|
return cast(str, self._result)
|
||||||
|
|
||||||
def _freeze_static(self) -> None:
|
def _freeze_static(self) -> None:
|
||||||
self.is_static = True
|
self.is_static = True
|
||||||
|
@ -288,26 +288,26 @@ class Template:
|
||||||
|
|
||||||
self.template: str = template.strip()
|
self.template: str = template.strip()
|
||||||
self._compiled_code = None
|
self._compiled_code = None
|
||||||
self._compiled = None
|
self._compiled: Optional[Template] = None
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.is_static = not is_template_string(template)
|
self.is_static = not is_template_string(template)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _env(self):
|
def _env(self) -> "TemplateEnvironment":
|
||||||
if self.hass is None:
|
if self.hass is None:
|
||||||
return _NO_HASS_ENV
|
return _NO_HASS_ENV
|
||||||
ret = self.hass.data.get(_ENVIRONMENT)
|
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass)
|
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) # type: ignore[no-untyped-call]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def ensure_valid(self):
|
def ensure_valid(self) -> None:
|
||||||
"""Return if template is valid."""
|
"""Return if template is valid."""
|
||||||
if self._compiled_code is not None:
|
if self._compiled_code is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._compiled_code = self._env.compile(self.template)
|
self._compiled_code = self._env.compile(self.template) # type: ignore[no-untyped-call]
|
||||||
except jinja2.TemplateError as err:
|
except jinja2.TemplateError as err:
|
||||||
raise TemplateError(err) from err
|
raise TemplateError(err) from err
|
||||||
|
|
||||||
|
@ -422,7 +422,7 @@ class Template:
|
||||||
|
|
||||||
finish_event = asyncio.Event()
|
finish_event = asyncio.Event()
|
||||||
|
|
||||||
def _render_template():
|
def _render_template() -> None:
|
||||||
try:
|
try:
|
||||||
compiled.render(kwargs)
|
compiled.render(kwargs)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
@ -449,7 +449,7 @@ class Template:
|
||||||
"""Render the template and collect an entity filter."""
|
"""Render the template and collect an entity filter."""
|
||||||
assert self.hass and _RENDER_INFO not in self.hass.data
|
assert self.hass and _RENDER_INFO not in self.hass.data
|
||||||
|
|
||||||
render_info = RenderInfo(self)
|
render_info = RenderInfo(self) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if self.is_static:
|
if self.is_static:
|
||||||
|
@ -519,7 +519,7 @@ class Template:
|
||||||
)
|
)
|
||||||
return value if error_value is _SENTINEL else error_value
|
return value if error_value is _SENTINEL else error_value
|
||||||
|
|
||||||
def _ensure_compiled(self):
|
def _ensure_compiled(self) -> "Template":
|
||||||
"""Bind a template to a specific hass instance."""
|
"""Bind a template to a specific hass instance."""
|
||||||
self.ensure_valid()
|
self.ensure_valid()
|
||||||
|
|
||||||
|
@ -527,8 +527,9 @@ class Template:
|
||||||
|
|
||||||
env = self._env
|
env = self._env
|
||||||
|
|
||||||
self._compiled = jinja2.Template.from_code(
|
self._compiled = cast(
|
||||||
env, self._compiled_code, env.globals, None
|
Template,
|
||||||
|
jinja2.Template.from_code(env, self._compiled_code, env.globals, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._compiled
|
return self._compiled
|
||||||
|
@ -553,7 +554,7 @@ class Template:
|
||||||
class AllStates:
|
class AllStates:
|
||||||
"""Class to expose all HA states as attributes."""
|
"""Class to expose all HA states as attributes."""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(self, hass: HomeAssistantType) -> None:
|
||||||
"""Initialize all states."""
|
"""Initialize all states."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
|
||||||
|
@ -607,7 +608,7 @@ class AllStates:
|
||||||
class DomainStates:
|
class DomainStates:
|
||||||
"""Class to expose a specific HA domain as attributes."""
|
"""Class to expose a specific HA domain as attributes."""
|
||||||
|
|
||||||
def __init__(self, hass, domain):
|
def __init__(self, hass: HomeAssistantType, domain: str) -> None:
|
||||||
"""Initialize the domain states."""
|
"""Initialize the domain states."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._domain = domain
|
self._domain = domain
|
||||||
|
@ -652,13 +653,15 @@ class TemplateState(State):
|
||||||
|
|
||||||
# Inheritance is done so functions that check against State keep working
|
# Inheritance is done so functions that check against State keep working
|
||||||
# pylint: disable=super-init-not-called
|
# pylint: disable=super-init-not-called
|
||||||
def __init__(self, hass, state, collect=True):
|
def __init__(
|
||||||
|
self, hass: HomeAssistantType, state: State, collect: bool = True
|
||||||
|
) -> None:
|
||||||
"""Initialize template state."""
|
"""Initialize template state."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._state = state
|
self._state = state
|
||||||
self._collect = collect
|
self._collect = collect
|
||||||
|
|
||||||
def _collect_state(self):
|
def _collect_state(self) -> None:
|
||||||
if self._collect and _RENDER_INFO in self._hass.data:
|
if self._collect and _RENDER_INFO in self._hass.data:
|
||||||
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
|
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
|
||||||
|
|
||||||
|
@ -1411,4 +1414,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
|
|
||||||
_NO_HASS_ENV = TemplateEnvironment(None)
|
_NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue