Helpers type hint improvements (#44964)

This commit is contained in:
Ville Skyttä 2021-01-09 01:08:34 +02:00 committed by GitHub
parent 3569d92385
commit 3a88a4120e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 45 deletions

View file

@ -556,7 +556,7 @@ def template(value: Optional[Any]) -> template_helper.Template:
template_value = template_helper.Template(str(value)) # type: ignore template_value = template_helper.Template(str(value)) # type: ignore
try: try:
template_value.ensure_valid() # type: ignore[no-untyped-call] template_value.ensure_valid()
return template_value return template_value
except TemplateError as ex: except TemplateError as ex:
raise vol.Invalid(f"invalid template ({ex})") from ex raise vol.Invalid(f"invalid template ({ex})") from ex
@ -574,7 +574,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
template_value = template_helper.Template(str(value)) # type: ignore template_value = template_helper.Template(str(value)) # type: ignore
try: try:
template_value.ensure_valid() # type: ignore[no-untyped-call] template_value.ensure_valid()
return template_value return template_value
except TemplateError as ex: except TemplateError as ex:
raise vol.Invalid(f"invalid template ({ex})") from ex raise vol.Invalid(f"invalid template ({ex})") from ex

View file

@ -26,7 +26,6 @@ from .event import async_call_later, async_track_time_interval
if TYPE_CHECKING: if TYPE_CHECKING:
from .entity import Entity from .entity import Entity
# mypy: allow-untyped-defs
SLOW_SETUP_WARNING = 10 SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60 SLOW_SETUP_MAX_WAIT = 60
@ -81,7 +80,7 @@ class EntityPlatform:
self.platform_name, [] self.platform_name, []
).append(self) ).append(self)
def __repr__(self): def __repr__(self) -> str:
"""Represent an EntityPlatform.""" """Represent an EntityPlatform."""
return f"<EntityPlatform domain={self.domain} platform_name={self.platform_name} config_entry={self.config_entry}>" return f"<EntityPlatform domain={self.domain} platform_name={self.platform_name} config_entry={self.config_entry}>"
@ -116,7 +115,7 @@ class EntityPlatform:
return self.parallel_updates return self.parallel_updates
async def async_setup(self, platform_config, discovery_info=None): async def async_setup(self, platform_config, discovery_info=None): # type: ignore[no-untyped-def]
"""Set up the platform from a config file.""" """Set up the platform from a config file."""
platform = self.platform platform = self.platform
hass = self.hass hass = self.hass
@ -162,7 +161,7 @@ class EntityPlatform:
platform = self.platform platform = self.platform
@callback @callback
def async_create_setup_task(): def async_create_setup_task(): # type: ignore[no-untyped-def]
"""Get task to set up platform.""" """Get task to set up platform."""
return platform.async_setup_entry( # type: ignore return platform.async_setup_entry( # type: ignore
self.hass, config_entry, self._async_schedule_add_entities self.hass, config_entry, self._async_schedule_add_entities
@ -218,7 +217,7 @@ class EntityPlatform:
wait_time, wait_time,
) )
async def setup_again(now): async def setup_again(now): # type: ignore[no-untyped-def]
"""Run setup again.""" """Run setup again."""
self._async_cancel_retry_setup = None self._async_cancel_retry_setup = None
await self._async_setup_platform(async_create_setup_task, tries) await self._async_setup_platform(async_create_setup_task, tries)
@ -340,7 +339,7 @@ class EntityPlatform:
self.scan_interval, self.scan_interval,
) )
async def _async_add_entity( async def _async_add_entity( # type: ignore[no-untyped-def]
self, entity, update_before_add, entity_registry, device_registry self, entity, update_before_add, entity_registry, device_registry
): ):
"""Add an entity to the platform.""" """Add an entity to the platform."""
@ -560,7 +559,7 @@ class EntityPlatform:
) )
@callback @callback
def async_register_entity_service(self, name, schema, func, required_features=None): def async_register_entity_service(self, name, schema, func, required_features=None): # type: ignore[no-untyped-def]
"""Register an entity service. """Register an entity service.
Services will automatically be shared by all platforms of the same domain. Services will automatically be shared by all platforms of the same domain.

View file

@ -219,7 +219,7 @@ class _ScriptRun:
self._stop = asyncio.Event() self._stop = asyncio.Event()
self._stopped = asyncio.Event() self._stopped = asyncio.Event()
def _changed(self): def _changed(self) -> None:
if not self._stop.is_set(): if not self._stop.is_set():
self._script._changed() # pylint: disable=protected-access self._script._changed() # pylint: disable=protected-access
@ -227,7 +227,7 @@ class _ScriptRun:
# pylint: disable=protected-access # pylint: disable=protected-access
return await self._script._async_get_condition(config) return await self._script._async_get_condition(config)
def _log(self, msg, *args, level=logging.INFO): def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
self._script._log(msg, *args, level=level) # pylint: disable=protected-access self._script._log(msg, *args, level=level) # pylint: disable=protected-access
async def async_run(self) -> None: async def async_run(self) -> None:
@ -257,7 +257,7 @@ class _ScriptRun:
self._log_exception(ex) self._log_exception(ex)
raise raise
def _finish(self): def _finish(self) -> None:
self._script._runs.remove(self) # pylint: disable=protected-access self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running: if not self._script.is_running:
self._script.last_action = None self._script.last_action = None
@ -389,7 +389,7 @@ class _ScriptRun:
async def _async_run_long_action(self, long_task): async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request.""" """Run a long task while monitoring for stop request."""
async def async_cancel_long_task(): async def async_cancel_long_task() -> None:
# Stop long task and wait for it to finish. # Stop long task and wait for it to finish.
long_task.cancel() long_task.cancel()
try: try:
@ -586,7 +586,7 @@ class _ScriptRun:
else: else:
del self._variables["repeat"] del self._variables["repeat"]
async def _async_choose_step(self): async def _async_choose_step(self) -> None:
"""Choose a sequence.""" """Choose a sequence."""
# pylint: disable=protected-access # pylint: disable=protected-access
choose_data = await self._script._async_get_choose_data(self._step) choose_data = await self._script._async_get_choose_data(self._step)
@ -706,7 +706,7 @@ class _QueuedScriptRun(_ScriptRun):
else: else:
await super().async_run() await super().async_run()
def _finish(self): def _finish(self) -> None:
# pylint: disable=protected-access # pylint: disable=protected-access
if self.lock_acquired: if self.lock_acquired:
self._script._queue_lck.release() self._script._queue_lck.release()
@ -868,7 +868,7 @@ class Script:
if choose_data["default"]: if choose_data["default"]:
choose_data["default"].update_logger(self._logger) choose_data["default"].update_logger(self._logger)
def _changed(self): def _changed(self) -> None:
if self._change_listener_job: if self._change_listener_job:
self._hass.async_run_hass_job(self._change_listener_job) self._hass.async_run_hass_job(self._change_listener_job)
@ -898,7 +898,7 @@ class Script:
if self._referenced_devices is not None: if self._referenced_devices is not None:
return self._referenced_devices return self._referenced_devices
referenced = set() referenced: Set[str] = set()
for step in self.sequence: for step in self.sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
@ -927,7 +927,7 @@ class Script:
if self._referenced_entities is not None: if self._referenced_entities is not None:
return self._referenced_entities return self._referenced_entities
referenced = set() referenced: Set[str] = set()
for step in self.sequence: for step in self.sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
@ -1128,9 +1128,9 @@ class Script:
self._choose_data[step] = choose_data self._choose_data[step] = choose_data
return choose_data return choose_data
def _log(self, msg, *args, level=logging.INFO): def _log(self, msg: str, *args: Any, level: int = logging.INFO) -> None:
msg = f"%s: {msg}" msg = f"%s: {msg}"
args = [self.name, *args] args = (self.name, *args)
if level == _LOG_EXCEPTION: if level == _LOG_EXCEPTION:
self._logger.exception(msg, *args) self._logger.exception(msg, *args)

View file

@ -11,7 +11,7 @@ import math
from operator import attrgetter from operator import attrgetter
import random import random
import re import re
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union from typing import Any, Dict, Generator, Iterable, Optional, Type, Union, cast
from urllib.parse import urlencode as urllib_urlencode from urllib.parse import urlencode as urllib_urlencode
import weakref import weakref
@ -38,8 +38,7 @@ from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException from homeassistant.util.thread import ThreadWithException
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SENTINEL = object() _SENTINEL = object()
@ -140,7 +139,7 @@ def gen_result_wrapper(kls):
if kls is set: if kls is set:
return str(set(self)) return str(set(self))
return kls.__str__(self) return cast(str, kls.__str__(self))
return self.render_result return self.render_result
@ -173,7 +172,8 @@ class TupleWrapper(tuple, ResultWrapper):
RESULT_WRAPPERS: Dict[Type, Type] = { RESULT_WRAPPERS: Dict[Type, Type] = {
kls: gen_result_wrapper(kls) for kls in (list, dict, set) kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call]
for kls in (list, dict, set)
} }
RESULT_WRAPPERS[tuple] = TupleWrapper RESULT_WRAPPERS[tuple] = TupleWrapper
@ -195,15 +195,15 @@ class RenderInfo:
# Will be set sensibly once frozen. # Will be set sensibly once frozen.
self.filter_lifecycle = _true self.filter_lifecycle = _true
self.filter = _true self.filter = _true
self._result = None self._result: Optional[str] = None
self.is_static = False self.is_static = False
self.exception = None self.exception: Optional[TemplateError] = None
self.all_states = False self.all_states = False
self.all_states_lifecycle = False self.all_states_lifecycle = False
self.domains = set() self.domains = set()
self.domains_lifecycle = set() self.domains_lifecycle = set()
self.entities = set() self.entities = set()
self.rate_limit = None self.rate_limit: Optional[timedelta] = None
self.has_time = False self.has_time = False
def __repr__(self) -> str: def __repr__(self) -> str:
@ -228,7 +228,7 @@ class RenderInfo:
"""Results of the template computation.""" """Results of the template computation."""
if self.exception is not None: if self.exception is not None:
raise self.exception raise self.exception
return self._result return cast(str, self._result)
def _freeze_static(self) -> None: def _freeze_static(self) -> None:
self.is_static = True self.is_static = True
@ -288,26 +288,26 @@ class Template:
self.template: str = template.strip() self.template: str = template.strip()
self._compiled_code = None self._compiled_code = None
self._compiled = None self._compiled: Optional[Template] = None
self.hass = hass self.hass = hass
self.is_static = not is_template_string(template) self.is_static = not is_template_string(template)
@property @property
def _env(self): def _env(self) -> "TemplateEnvironment":
if self.hass is None: if self.hass is None:
return _NO_HASS_ENV return _NO_HASS_ENV
ret = self.hass.data.get(_ENVIRONMENT) ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
if ret is None: if ret is None:
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) # type: ignore[no-untyped-call]
return ret return ret
def ensure_valid(self): def ensure_valid(self) -> None:
"""Return if template is valid.""" """Return if template is valid."""
if self._compiled_code is not None: if self._compiled_code is not None:
return return
try: try:
self._compiled_code = self._env.compile(self.template) self._compiled_code = self._env.compile(self.template) # type: ignore[no-untyped-call]
except jinja2.TemplateError as err: except jinja2.TemplateError as err:
raise TemplateError(err) from err raise TemplateError(err) from err
@ -422,7 +422,7 @@ class Template:
finish_event = asyncio.Event() finish_event = asyncio.Event()
def _render_template(): def _render_template() -> None:
try: try:
compiled.render(kwargs) compiled.render(kwargs)
except TimeoutError: except TimeoutError:
@ -449,7 +449,7 @@ class Template:
"""Render the template and collect an entity filter.""" """Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data assert self.hass and _RENDER_INFO not in self.hass.data
render_info = RenderInfo(self) render_info = RenderInfo(self) # type: ignore[no-untyped-call]
# pylint: disable=protected-access # pylint: disable=protected-access
if self.is_static: if self.is_static:
@ -519,7 +519,7 @@ class Template:
) )
return value if error_value is _SENTINEL else error_value return value if error_value is _SENTINEL else error_value
def _ensure_compiled(self): def _ensure_compiled(self) -> "Template":
"""Bind a template to a specific hass instance.""" """Bind a template to a specific hass instance."""
self.ensure_valid() self.ensure_valid()
@ -527,8 +527,9 @@ class Template:
env = self._env env = self._env
self._compiled = jinja2.Template.from_code( self._compiled = cast(
env, self._compiled_code, env.globals, None Template,
jinja2.Template.from_code(env, self._compiled_code, env.globals, None),
) )
return self._compiled return self._compiled
@ -553,7 +554,7 @@ class Template:
class AllStates: class AllStates:
"""Class to expose all HA states as attributes.""" """Class to expose all HA states as attributes."""
def __init__(self, hass): def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize all states.""" """Initialize all states."""
self._hass = hass self._hass = hass
@ -607,7 +608,7 @@ class AllStates:
class DomainStates: class DomainStates:
"""Class to expose a specific HA domain as attributes.""" """Class to expose a specific HA domain as attributes."""
def __init__(self, hass, domain): def __init__(self, hass: HomeAssistantType, domain: str) -> None:
"""Initialize the domain states.""" """Initialize the domain states."""
self._hass = hass self._hass = hass
self._domain = domain self._domain = domain
@ -652,13 +653,15 @@ class TemplateState(State):
# Inheritance is done so functions that check against State keep working # Inheritance is done so functions that check against State keep working
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
def __init__(self, hass, state, collect=True): def __init__(
self, hass: HomeAssistantType, state: State, collect: bool = True
) -> None:
"""Initialize template state.""" """Initialize template state."""
self._hass = hass self._hass = hass
self._state = state self._state = state
self._collect = collect self._collect = collect
def _collect_state(self): def _collect_state(self) -> None:
if self._collect and _RENDER_INFO in self._hass.data: if self._collect and _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id) self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
@ -1411,4 +1414,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return cached return cached
_NO_HASS_ENV = TemplateEnvironment(None) _NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call]