Helpers type hint improvements (#44964)

This commit is contained in:
Ville Skyttä 2021-01-09 01:08:34 +02:00 committed by GitHub
parent 3569d92385
commit 3a88a4120e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 45 deletions

View file

@ -556,7 +556,7 @@ def template(value: Optional[Any]) -> template_helper.Template:
template_value = template_helper.Template(str(value)) # type: ignore
try:
template_value.ensure_valid() # type: ignore[no-untyped-call]
template_value.ensure_valid()
return template_value
except TemplateError as ex:
raise vol.Invalid(f"invalid template ({ex})") from ex
@ -574,7 +574,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
template_value = template_helper.Template(str(value)) # type: ignore
try:
template_value.ensure_valid() # type: ignore[no-untyped-call]
template_value.ensure_valid()
return template_value
except TemplateError as ex:
raise vol.Invalid(f"invalid template ({ex})") from ex

View file

@ -26,7 +26,6 @@ from .event import async_call_later, async_track_time_interval
if TYPE_CHECKING:
from .entity import Entity
# mypy: allow-untyped-defs
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
@ -81,7 +80,7 @@ class EntityPlatform:
self.platform_name, []
).append(self)
def __repr__(self):
def __repr__(self) -> str:
"""Represent an EntityPlatform."""
return f"<EntityPlatform domain={self.domain} platform_name={self.platform_name} config_entry={self.config_entry}>"
@ -116,7 +115,7 @@ class EntityPlatform:
return self.parallel_updates
async def async_setup(self, platform_config, discovery_info=None):
async def async_setup(self, platform_config, discovery_info=None): # type: ignore[no-untyped-def]
"""Set up the platform from a config file."""
platform = self.platform
hass = self.hass
@ -162,7 +161,7 @@ class EntityPlatform:
platform = self.platform
@callback
def async_create_setup_task():
def async_create_setup_task(): # type: ignore[no-untyped-def]
"""Get task to set up platform."""
return platform.async_setup_entry( # type: ignore
self.hass, config_entry, self._async_schedule_add_entities
@ -218,7 +217,7 @@ class EntityPlatform:
wait_time,
)
async def setup_again(now):
async def setup_again(now): # type: ignore[no-untyped-def]
"""Run setup again."""
self._async_cancel_retry_setup = None
await self._async_setup_platform(async_create_setup_task, tries)
@ -340,7 +339,7 @@ class EntityPlatform:
self.scan_interval,
)
async def _async_add_entity(
async def _async_add_entity( # type: ignore[no-untyped-def]
self, entity, update_before_add, entity_registry, device_registry
):
"""Add an entity to the platform."""
@ -560,7 +559,7 @@ class EntityPlatform:
)
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
def async_register_entity_service(self, name, schema, func, required_features=None): # type: ignore[no-untyped-def]
"""Register an entity service.
Services will automatically be shared by all platforms of the same domain.

View file

@ -219,7 +219,7 @@ class _ScriptRun:
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
def _changed(self):
def _changed(self) -> None:
if not self._stop.is_set():
self._script._changed() # pylint: disable=protected-access
@ -227,7 +227,7 @@ class _ScriptRun:
# pylint: disable=protected-access
return await self._script._async_get_condition(config)
def _log(self, msg, *args, level=logging.INFO):
def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
async def async_run(self) -> None:
@ -257,7 +257,7 @@ class _ScriptRun:
self._log_exception(ex)
raise
def _finish(self):
def _finish(self) -> None:
self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running:
self._script.last_action = None
@ -389,7 +389,7 @@ class _ScriptRun:
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
async def async_cancel_long_task() -> None:
# Stop long task and wait for it to finish.
long_task.cancel()
try:
@ -586,7 +586,7 @@ class _ScriptRun:
else:
del self._variables["repeat"]
async def _async_choose_step(self):
async def _async_choose_step(self) -> None:
"""Choose a sequence."""
# pylint: disable=protected-access
choose_data = await self._script._async_get_choose_data(self._step)
@ -706,7 +706,7 @@ class _QueuedScriptRun(_ScriptRun):
else:
await super().async_run()
def _finish(self):
def _finish(self) -> None:
# pylint: disable=protected-access
if self.lock_acquired:
self._script._queue_lck.release()
@ -868,7 +868,7 @@ class Script:
if choose_data["default"]:
choose_data["default"].update_logger(self._logger)
def _changed(self):
def _changed(self) -> None:
if self._change_listener_job:
self._hass.async_run_hass_job(self._change_listener_job)
@ -898,7 +898,7 @@ class Script:
if self._referenced_devices is not None:
return self._referenced_devices
referenced = set()
referenced: Set[str] = set()
for step in self.sequence:
action = cv.determine_script_action(step)
@ -927,7 +927,7 @@ class Script:
if self._referenced_entities is not None:
return self._referenced_entities
referenced = set()
referenced: Set[str] = set()
for step in self.sequence:
action = cv.determine_script_action(step)
@ -1128,9 +1128,9 @@ class Script:
self._choose_data[step] = choose_data
return choose_data
def _log(self, msg, *args, level=logging.INFO):
def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
msg = f"%s: {msg}"
args = [self.name, *args]
args = (self.name, *args)
if level == _LOG_EXCEPTION:
self._logger.exception(msg, *args)

View file

@ -11,7 +11,7 @@ import math
from operator import attrgetter
import random
import re
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union, cast
from urllib.parse import urlencode as urllib_urlencode
import weakref
@ -38,8 +38,7 @@ from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
_SENTINEL = object()
@ -140,7 +139,7 @@ def gen_result_wrapper(kls):
if kls is set:
return str(set(self))
return kls.__str__(self)
return cast(str, kls.__str__(self))
return self.render_result
@ -173,7 +172,8 @@ class TupleWrapper(tuple, ResultWrapper):
RESULT_WRAPPERS: Dict[Type, Type] = {
kls: gen_result_wrapper(kls) for kls in (list, dict, set)
kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call]
for kls in (list, dict, set)
}
RESULT_WRAPPERS[tuple] = TupleWrapper
@ -195,15 +195,15 @@ class RenderInfo:
# Will be set sensibly once frozen.
self.filter_lifecycle = _true
self.filter = _true
self._result = None
self._result: Optional[str] = None
self.is_static = False
self.exception = None
self.exception: Optional[TemplateError] = None
self.all_states = False
self.all_states_lifecycle = False
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
self.rate_limit = None
self.rate_limit: Optional[timedelta] = None
self.has_time = False
def __repr__(self) -> str:
@ -228,7 +228,7 @@ class RenderInfo:
"""Results of the template computation."""
if self.exception is not None:
raise self.exception
return self._result
return cast(str, self._result)
def _freeze_static(self) -> None:
self.is_static = True
@ -288,26 +288,26 @@ class Template:
self.template: str = template.strip()
self._compiled_code = None
self._compiled = None
self._compiled: Optional[Template] = None
self.hass = hass
self.is_static = not is_template_string(template)
@property
def _env(self):
def _env(self) -> "TemplateEnvironment":
if self.hass is None:
return _NO_HASS_ENV
ret = self.hass.data.get(_ENVIRONMENT)
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
if ret is None:
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass)
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) # type: ignore[no-untyped-call]
return ret
def ensure_valid(self):
def ensure_valid(self) -> None:
"""Return if template is valid."""
if self._compiled_code is not None:
return
try:
self._compiled_code = self._env.compile(self.template)
self._compiled_code = self._env.compile(self.template) # type: ignore[no-untyped-call]
except jinja2.TemplateError as err:
raise TemplateError(err) from err
@ -422,7 +422,7 @@ class Template:
finish_event = asyncio.Event()
def _render_template():
def _render_template() -> None:
try:
compiled.render(kwargs)
except TimeoutError:
@ -449,7 +449,7 @@ class Template:
"""Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data
render_info = RenderInfo(self)
render_info = RenderInfo(self) # type: ignore[no-untyped-call]
# pylint: disable=protected-access
if self.is_static:
@ -519,7 +519,7 @@ class Template:
)
return value if error_value is _SENTINEL else error_value
def _ensure_compiled(self):
def _ensure_compiled(self) -> "Template":
"""Bind a template to a specific hass instance."""
self.ensure_valid()
@ -527,8 +527,9 @@ class Template:
env = self._env
self._compiled = jinja2.Template.from_code(
env, self._compiled_code, env.globals, None
self._compiled = cast(
Template,
jinja2.Template.from_code(env, self._compiled_code, env.globals, None),
)
return self._compiled
@ -553,7 +554,7 @@ class Template:
class AllStates:
"""Class to expose all HA states as attributes."""
def __init__(self, hass):
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize all states."""
self._hass = hass
@ -607,7 +608,7 @@ class AllStates:
class DomainStates:
"""Class to expose a specific HA domain as attributes."""
def __init__(self, hass, domain):
def __init__(self, hass: HomeAssistantType, domain: str) -> None:
"""Initialize the domain states."""
self._hass = hass
self._domain = domain
@ -652,13 +653,15 @@ class TemplateState(State):
# Inheritance is done so functions that check against State keep working
# pylint: disable=super-init-not-called
def __init__(self, hass, state, collect=True):
def __init__(
self, hass: HomeAssistantType, state: State, collect: bool = True
) -> None:
"""Initialize template state."""
self._hass = hass
self._state = state
self._collect = collect
def _collect_state(self):
def _collect_state(self) -> None:
if self._collect and _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
@ -1411,4 +1414,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return cached
_NO_HASS_ENV = TemplateEnvironment(None)
_NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call]