Type hint improvements (#33082)

This commit is contained in:
Ville Skyttä 2020-04-17 21:33:58 +03:00 committed by GitHub
parent f04be61f6f
commit 267d98b5eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 124 additions and 73 deletions

View file

@ -486,7 +486,7 @@ class Event:
def __init__( def __init__(
self, self,
event_type: str, event_type: str,
data: Optional[Dict] = None, data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None, time_fired: Optional[int] = None,
context: Optional[Context] = None, context: Optional[Context] = None,
@ -550,9 +550,7 @@ class EventBus:
@property @property
def listeners(self) -> Dict[str, int]: def listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners.""" """Return dictionary with events and the number of listeners."""
return run_callback_threadsafe( # type: ignore return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
self._hass.loop, self.async_listeners
).result()
def fire( def fire(
self, self,
@ -852,7 +850,7 @@ class StateMachine:
future = run_callback_threadsafe( future = run_callback_threadsafe(
self._loop, self.async_entity_ids, domain_filter self._loop, self.async_entity_ids, domain_filter
) )
return future.result() # type: ignore return future.result()
@callback @callback
def async_entity_ids(self, domain_filter: Optional[str] = None) -> List[str]: def async_entity_ids(self, domain_filter: Optional[str] = None) -> List[str]:
@ -873,9 +871,7 @@ class StateMachine:
def all(self) -> List[State]: def all(self) -> List[State]:
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe( # type: ignore return run_callback_threadsafe(self._loop, self.async_all).result()
self._loop, self.async_all
).result()
@callback @callback
def async_all(self) -> List[State]: def async_all(self) -> List[State]:
@ -905,7 +901,7 @@ class StateMachine:
Returns boolean to indicate if an entity was removed. Returns boolean to indicate if an entity was removed.
""" """
return run_callback_threadsafe( # type: ignore return run_callback_threadsafe(
self._loop, self.async_remove, entity_id self._loop, self.async_remove, entity_id
).result() ).result()
@ -1064,9 +1060,7 @@ class ServiceRegistry:
@property @property
def services(self) -> Dict[str, Dict[str, Service]]: def services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services.""" """Return dictionary with per domain a list of available services."""
return run_callback_threadsafe( # type: ignore return run_callback_threadsafe(self._hass.loop, self.async_services).result()
self._hass.loop, self.async_services
).result()
@callback @callback
def async_services(self) -> Dict[str, Dict[str, Service]]: def async_services(self) -> Dict[str, Dict[str, Service]]:

View file

@ -146,19 +146,16 @@ def numeric_state(
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
) -> bool: ) -> bool:
"""Test a numeric state condition.""" """Test a numeric state condition."""
return cast( return run_callback_threadsafe(
bool, hass.loop,
run_callback_threadsafe( async_numeric_state,
hass.loop, hass,
async_numeric_state, entity,
hass, below,
entity, above,
below, value_template,
above, variables,
value_template, ).result()
variables,
).result(),
)
def async_numeric_state( def async_numeric_state(
@ -353,12 +350,9 @@ def template(
hass: HomeAssistant, value_template: Template, variables: TemplateVarsType = None hass: HomeAssistant, value_template: Template, variables: TemplateVarsType = None
) -> bool: ) -> bool:
"""Test if template condition matches.""" """Test if template condition matches."""
return cast( return run_callback_threadsafe(
bool, hass.loop, async_template, hass, value_template, variables
run_callback_threadsafe( ).result()
hass.loop, async_template, hass, value_template, variables
).result(),
)
def async_template( def async_template(

View file

@ -10,11 +10,10 @@ from typing import Any, Callable, Collection, Dict, Optional, Union
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import DEPENDENCY_BLACKLIST, bind_hass from homeassistant.loader import DEPENDENCY_BLACKLIST, bind_hass
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs
EVENT_LOAD_PLATFORM = "load_platform.{}" EVENT_LOAD_PLATFORM = "load_platform.{}"
ATTR_PLATFORM = "platform" ATTR_PLATFORM = "platform"
@ -56,13 +55,29 @@ def async_listen(
@bind_hass @bind_hass
def discover(hass, service, discovered, component, hass_config): def discover(
hass: core.HomeAssistant,
service: str,
discovered: DiscoveryInfoType,
component: str,
hass_config: ConfigType,
) -> None:
"""Fire discovery event. Can ensure a component is loaded.""" """Fire discovery event. Can ensure a component is loaded."""
hass.add_job(async_discover(hass, service, discovered, component, hass_config)) hass.add_job(
async_discover( # type: ignore
hass, service, discovered, component, hass_config
)
)
@bind_hass @bind_hass
async def async_discover(hass, service, discovered, component, hass_config): async def async_discover(
hass: core.HomeAssistant,
service: str,
discovered: Optional[DiscoveryInfoType],
component: Optional[str],
hass_config: ConfigType,
) -> None:
"""Fire discovery event. Can ensure a component is loaded.""" """Fire discovery event. Can ensure a component is loaded."""
if component in DEPENDENCY_BLACKLIST: if component in DEPENDENCY_BLACKLIST:
raise HomeAssistantError(f"Cannot discover the {component} component.") raise HomeAssistantError(f"Cannot discover the {component} component.")
@ -70,7 +85,7 @@ async def async_discover(hass, service, discovered, component, hass_config):
if component is not None and component not in hass.config.components: if component is not None and component not in hass.config.components:
await setup.async_setup_component(hass, component, hass_config) await setup.async_setup_component(hass, component, hass_config)
data = {ATTR_SERVICE: service} data: Dict[str, Any] = {ATTR_SERVICE: service}
if discovered is not None: if discovered is not None:
data[ATTR_DISCOVERED] = discovered data[ATTR_DISCOVERED] = discovered
@ -117,7 +132,13 @@ def async_listen_platform(
@bind_hass @bind_hass
def load_platform(hass, component, platform, discovered, hass_config): def load_platform(
hass: core.HomeAssistant,
component: str,
platform: str,
discovered: DiscoveryInfoType,
hass_config: ConfigType,
) -> None:
"""Load a component and platform dynamically. """Load a component and platform dynamically.
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
@ -129,12 +150,20 @@ def load_platform(hass, component, platform, discovered, hass_config):
Use `listen_platform` to register a callback for these events. Use `listen_platform` to register a callback for these events.
""" """
hass.add_job( hass.add_job(
async_load_platform(hass, component, platform, discovered, hass_config) async_load_platform( # type: ignore
hass, component, platform, discovered, hass_config
)
) )
@bind_hass @bind_hass
async def async_load_platform(hass, component, platform, discovered, hass_config): async def async_load_platform(
hass: core.HomeAssistant,
component: str,
platform: str,
discovered: DiscoveryInfoType,
hass_config: ConfigType,
) -> None:
"""Load a component and platform dynamically. """Load a component and platform dynamically.
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
@ -164,7 +193,7 @@ async def async_load_platform(hass, component, platform, discovered, hass_config
if not setup_success: if not setup_success:
return return
data = { data: Dict[str, Any] = {
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component), ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform, ATTR_PLATFORM: platform,
} }

View file

@ -35,7 +35,7 @@ from homeassistant.helpers.entity_registry import (
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs, no-warn-return-any # mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10 SLOW_UPDATE_WARNING = 10

View file

@ -191,7 +191,7 @@ class EntityComponent:
This method must be run in the event loop. This method must be run in the event loop.
""" """
return await service.async_extract_entities( # type: ignore return await service.async_extract_entities(
self.hass, self.entities, service_call, expand_group self.hass, self.entities, service_call, expand_group
) )

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
from functools import partial, wraps from functools import partial, wraps
import logging import logging
from typing import Callable from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
import voluptuous as vol import voluptuous as vol
@ -22,9 +22,10 @@ from homeassistant.exceptions import (
Unauthorized, Unauthorized,
UnknownUser, UnknownUser,
) )
from homeassistant.helpers import template, typing from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.loader import async_get_integration, bind_hass from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
@ -42,8 +43,12 @@ SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@bind_hass @bind_hass
def call_from_config( def call_from_config(
hass, config, blocking=False, variables=None, validate_config=True hass: HomeAssistantType,
): config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
) -> None:
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
async_call_from_config(hass, config, blocking, variables, validate_config), async_call_from_config(hass, config, blocking, variables, validate_config),
@ -53,8 +58,13 @@ def call_from_config(
@bind_hass @bind_hass
async def async_call_from_config( async def async_call_from_config(
hass, config, blocking=False, variables=None, validate_config=True, context=None hass: HomeAssistantType,
): config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
context: Optional[ha.Context] = None,
) -> None:
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
try: try:
parms = async_prepare_call_from_config(hass, config, variables, validate_config) parms = async_prepare_call_from_config(hass, config, variables, validate_config)
@ -68,7 +78,12 @@ async def async_call_from_config(
@ha.callback @ha.callback
@bind_hass @bind_hass
def async_prepare_call_from_config(hass, config, variables=None, validate_config=False): def async_prepare_call_from_config(
hass: HomeAssistantType,
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
) -> Tuple[str, str, Dict[str, Any]]:
"""Prepare to call a service based on a config hash.""" """Prepare to call a service based on a config hash."""
if validate_config: if validate_config:
try: try:
@ -113,7 +128,9 @@ def async_prepare_call_from_config(hass, config, variables=None, validate_config
@bind_hass @bind_hass
def extract_entity_ids(hass, service_call, expand_group=True): def extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
"""Extract a list of entity ids from a service call. """Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -124,7 +141,12 @@ def extract_entity_ids(hass, service_call, expand_group=True):
@bind_hass @bind_hass
async def async_extract_entities(hass, entities, service_call, expand_group=True): async def async_extract_entities(
hass: HomeAssistantType,
entities: Iterable[Entity],
service_call: ha.ServiceCall,
expand_group: bool = True,
) -> List[Entity]:
"""Extract a list of entity objects from a service call. """Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -158,7 +180,9 @@ async def async_extract_entities(hass, entities, service_call, expand_group=True
@bind_hass @bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True): async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
"""Extract a list of entity ids from a service call. """Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
@ -166,7 +190,7 @@ async def async_extract_entity_ids(hass, service_call, expand_group=True):
entity_ids = service_call.data.get(ATTR_ENTITY_ID) entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID) area_ids = service_call.data.get(ATTR_AREA_ID)
extracted = set() extracted: Set[str] = set()
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in ( if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
None, None,
@ -226,7 +250,9 @@ async def _load_services_file(hass: HomeAssistantType, domain: str) -> JSON_TYPE
@bind_hass @bind_hass
async def async_get_all_descriptions(hass): async def async_get_all_descriptions(
hass: HomeAssistantType,
) -> Dict[str, Dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls.""" """Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format format_cache_key = "{}.{}".format
@ -253,7 +279,7 @@ async def async_get_all_descriptions(hass):
loaded[domain] = content loaded[domain] = content
# Build response # Build response
descriptions = {} descriptions: Dict[str, Dict[str, Any]] = {}
for domain in services: for domain in services:
descriptions[domain] = {} descriptions[domain] = {}
@ -281,7 +307,9 @@ async def async_get_all_descriptions(hass):
@ha.callback @ha.callback
@bind_hass @bind_hass
def async_set_service_schema(hass, domain, service, schema): def async_set_service_schema(
hass: HomeAssistantType, domain: str, service: str, schema: Dict[str, Any]
) -> None:
"""Register a description for a service.""" """Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {}) hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -454,7 +482,7 @@ async def _handle_entity_call(hass, entity, func, data, context):
@bind_hass @bind_hass
@ha.callback @ha.callback
def async_register_admin_service( def async_register_admin_service(
hass: typing.HomeAssistantType, hass: HomeAssistantType,
domain: str, domain: str,
service: str, service: str,
service_func: Callable, service_func: Callable,

View file

@ -51,7 +51,7 @@ _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{")
@bind_hass @bind_hass
def attach(hass, obj): def attach(hass: HomeAssistantType, obj: Any) -> None:
"""Recursively attach hass to all template instances in list and dict.""" """Recursively attach hass to all template instances in list and dict."""
if isinstance(obj, list): if isinstance(obj, list):
for child in obj: for child in obj:
@ -63,7 +63,7 @@ def attach(hass, obj):
obj.hass = hass obj.hass = hass
def render_complex(value, variables=None): def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
"""Recursive template creator helper function.""" """Recursive template creator helper function."""
if isinstance(value, list): if isinstance(value, list):
return [render_complex(item, variables) for item in value] return [render_complex(item, variables) for item in value]
@ -307,11 +307,11 @@ class Template:
and self.hass == other.hass and self.hass == other.hass
) )
def __hash__(self): def __hash__(self) -> int:
"""Hash code for template.""" """Hash code for template."""
return hash(self.template) return hash(self.template)
def __repr__(self): def __repr__(self) -> str:
"""Representation of Template.""" """Representation of Template."""
return 'Template("' + self.template + '")' return 'Template("' + self.template + '")'
@ -333,7 +333,7 @@ class AllStates:
raise TemplateError(f"Invalid domain name '{name}'") raise TemplateError(f"Invalid domain name '{name}'")
return DomainStates(self._hass, name) return DomainStates(self._hass, name)
def _collect_all(self): def _collect_all(self) -> None:
render_info = self._hass.data.get(_RENDER_INFO) render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None: if render_info is not None:
# pylint: disable=protected-access # pylint: disable=protected-access
@ -349,7 +349,7 @@ class AllStates:
) )
) )
def __len__(self): def __len__(self) -> int:
"""Return number of states.""" """Return number of states."""
self._collect_all() self._collect_all()
return len(self._hass.states.async_entity_ids()) return len(self._hass.states.async_entity_ids())
@ -359,7 +359,7 @@ class AllStates:
state = _get_state(self._hass, entity_id) state = _get_state(self._hass, entity_id)
return STATE_UNKNOWN if state is None else state.state return STATE_UNKNOWN if state is None else state.state
def __repr__(self): def __repr__(self) -> str:
"""Representation of All States.""" """Representation of All States."""
return "<template AllStates>" return "<template AllStates>"
@ -455,19 +455,21 @@ class TemplateState(State):
return f"<template {rep[1:]}" return f"<template {rep[1:]}"
def _collect_state(hass, entity_id): def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect = hass.data.get(_RENDER_INFO) entity_collect = hass.data.get(_RENDER_INFO)
if entity_collect is not None: if entity_collect is not None:
# pylint: disable=protected-access # pylint: disable=protected-access
entity_collect._entities.append(entity_id) entity_collect._entities.append(entity_id)
def _wrap_state(hass, state): def _wrap_state(
hass: HomeAssistantType, state: Optional[State]
) -> Optional[TemplateState]:
"""Wrap a state.""" """Wrap a state."""
return None if state is None else TemplateState(hass, state) return None if state is None else TemplateState(hass, state)
def _get_state(hass, entity_id): def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
if state is None: if state is None:
# Only need to collect if none, if not none collect first actual # Only need to collect if none, if not none collect first actual
@ -477,7 +479,9 @@ def _get_state(hass, entity_id):
return _wrap_state(hass, state) return _wrap_state(hass, state)
def _resolve_state(hass, entity_id_or_state): def _resolve_state(
hass: HomeAssistantType, entity_id_or_state: Any
) -> Union[State, TemplateState, None]:
"""Return state or entity_id if given.""" """Return state or entity_id if given."""
if isinstance(entity_id_or_state, State): if isinstance(entity_id_or_state, State):
return entity_id_or_state return entity_id_or_state

View file

@ -6,10 +6,12 @@ import functools
import logging import logging
import threading import threading
from traceback import extract_stack from traceback import extract_stack
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine, TypeVar
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
T = TypeVar("T")
def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None: def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
"""Submit a coroutine object to a given event loop. """Submit a coroutine object to a given event loop.
@ -33,8 +35,8 @@ def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
def run_callback_threadsafe( def run_callback_threadsafe(
loop: AbstractEventLoop, callback: Callable, *args: Any loop: AbstractEventLoop, callback: Callable[..., T], *args: Any
) -> concurrent.futures.Future: ) -> "concurrent.futures.Future[T]":
"""Submit a callback object to a given event loop. """Submit a callback object to a given event loop.
Return a concurrent.futures.Future to access the result. Return a concurrent.futures.Future to access the result.