Add typing to homeassistant/*.py and homeassistant/util/ (#15569)

* Add typing to homeassistant/*.py and homeassistant/util/

* Fix wrong merge

* Restore iterable in OrderedSet

* Fix tests
This commit is contained in:
Andrey 2018-07-23 11:24:39 +03:00 committed by Paulus Schoutsen
parent b7c336a687
commit 140a874917
27 changed files with 532 additions and 384 deletions

View file

@ -20,7 +20,7 @@ from homeassistant.const import (
) )
def attempt_use_uvloop(): def attempt_use_uvloop() -> None:
"""Attempt to use uvloop.""" """Attempt to use uvloop."""
import asyncio import asyncio
@ -280,7 +280,7 @@ def setup_and_run_hass(config_dir: str,
# Imported here to avoid importing asyncio before monkey patch # Imported here to avoid importing asyncio before monkey patch
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
def open_browser(event): def open_browser(_: Any) -> None:
"""Open the web interface in a browser.""" """Open the web interface in a browser."""
if hass.config.api is not None: # type: ignore if hass.config.api is not None: # type: ignore
import webbrowser import webbrowser

View file

@ -221,8 +221,8 @@ async def async_from_config_file(config_path: str,
@core.callback @core.callback
def async_enable_logging(hass: core.HomeAssistant, def async_enable_logging(hass: core.HomeAssistant,
verbose: bool = False, verbose: bool = False,
log_rotate_days=None, log_rotate_days: Optional[int] = None,
log_file=None, log_file: Optional[str] = None,
log_no_color: bool = False) -> None: log_no_color: bool = False) -> None:
"""Set up the logging. """Set up the logging.
@ -291,7 +291,7 @@ def async_enable_logging(hass: core.HomeAssistant,
async_handler = AsyncHandler(hass.loop, err_handler) async_handler = AsyncHandler(hass.loop, err_handler)
async def async_stop_async_handler(event): async def async_stop_async_handler(_: Any) -> None:
"""Cleanup async handler.""" """Cleanup async handler."""
logging.getLogger('').removeHandler(async_handler) # type: ignore logging.getLogger('').removeHandler(async_handler) # type: ignore
await async_handler.async_close(blocking=True) await async_handler.async_close(blocking=True)

View file

@ -9,7 +9,7 @@ import logging
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
from typing import Optional
from homeassistant.auth.util import generate_secret from homeassistant.auth.util import generate_secret
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
@ -241,7 +241,7 @@ class RachioIro:
# Only enabled zones # Only enabled zones
return [z for z in self._zones if z[KEY_ENABLED]] return [z for z in self._zones if z[KEY_ENABLED]]
def get_zone(self, zone_id) -> dict or None: def get_zone(self, zone_id) -> Optional[dict]:
"""Return the zone with the given ID.""" """Return the zone with the given ID."""
for zone in self.list_zones(include_disabled=True): for zone in self.list_zones(include_disabled=True):
if zone[KEY_ID] == zone_id: if zone[KEY_ID] == zone_id:

View file

@ -7,8 +7,9 @@ import os
import re import re
import shutil import shutil
# pylint: disable=unused-import # pylint: disable=unused-import
from typing import Any, Tuple, Optional # noqa: F401 from typing import ( # noqa: F401
Any, Tuple, Optional, Dict, List, Union, Callable)
from types import ModuleType
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -21,7 +22,7 @@ from homeassistant.const import (
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS, CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
__version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB, __version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB,
CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE) CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE)
from homeassistant.core import callback, DOMAIN as CONF_CORE from homeassistant.core import callback, DOMAIN as CONF_CORE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component, get_platform from homeassistant.loader import get_component, get_platform
from homeassistant.util.yaml import load_yaml, SECRET_YAML from homeassistant.util.yaml import load_yaml, SECRET_YAML
@ -193,7 +194,7 @@ def ensure_config_exists(config_dir: str, detect_location: bool = True)\
return config_path return config_path
def create_default_config(config_dir: str, detect_location=True)\ def create_default_config(config_dir: str, detect_location: bool = True)\
-> Optional[str]: -> Optional[str]:
"""Create a default configuration file in given configuration directory. """Create a default configuration file in given configuration directory.
@ -276,7 +277,7 @@ def create_default_config(config_dir: str, detect_location=True)\
return None return None
async def async_hass_config_yaml(hass): async def async_hass_config_yaml(hass: HomeAssistant) -> Dict:
"""Load YAML from a Home Assistant configuration file. """Load YAML from a Home Assistant configuration file.
This function allow a component inside the asyncio loop to reload its This function allow a component inside the asyncio loop to reload its
@ -284,23 +285,26 @@ async def async_hass_config_yaml(hass):
This method is a coroutine. This method is a coroutine.
""" """
def _load_hass_yaml_config(): def _load_hass_yaml_config() -> Dict:
path = find_config_file(hass.config.config_dir) path = find_config_file(hass.config.config_dir)
conf = load_yaml_config_file(path) if path is None:
return conf raise HomeAssistantError(
"Config file not found in: {}".format(hass.config.config_dir))
return load_yaml_config_file(path)
conf = await hass.async_add_job(_load_hass_yaml_config) return await hass.async_add_executor_job(_load_hass_yaml_config)
return conf
def find_config_file(config_dir: str) -> Optional[str]: def find_config_file(config_dir: Optional[str]) -> Optional[str]:
"""Look in given directory for supported configuration files.""" """Look in given directory for supported configuration files."""
if config_dir is None:
return None
config_path = os.path.join(config_dir, YAML_CONFIG_FILE) config_path = os.path.join(config_dir, YAML_CONFIG_FILE)
return config_path if os.path.isfile(config_path) else None return config_path if os.path.isfile(config_path) else None
def load_yaml_config_file(config_path): def load_yaml_config_file(config_path: str) -> Dict[Any, Any]:
"""Parse a YAML configuration file. """Parse a YAML configuration file.
This method needs to run in an executor. This method needs to run in an executor.
@ -323,7 +327,7 @@ def load_yaml_config_file(config_path):
return conf_dict return conf_dict
def process_ha_config_upgrade(hass): def process_ha_config_upgrade(hass: HomeAssistant) -> None:
"""Upgrade configuration if necessary. """Upgrade configuration if necessary.
This method needs to run in an executor. This method needs to run in an executor.
@ -360,7 +364,8 @@ def process_ha_config_upgrade(hass):
@callback @callback
def async_log_exception(ex, domain, config, hass): def async_log_exception(ex: vol.Invalid, domain: str, config: Dict,
hass: HomeAssistant) -> None:
"""Log an error for configuration validation. """Log an error for configuration validation.
This method must be run in the event loop. This method must be run in the event loop.
@ -371,7 +376,7 @@ def async_log_exception(ex, domain, config, hass):
@callback @callback
def _format_config_error(ex, domain, config): def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str:
"""Generate log exception for configuration validation. """Generate log exception for configuration validation.
This method must be run in the event loop. This method must be run in the event loop.
@ -396,7 +401,8 @@ def _format_config_error(ex, domain, config):
return message return message
async def async_process_ha_core_config(hass, config): async def async_process_ha_core_config(
hass: HomeAssistant, config: Dict) -> None:
"""Process the [homeassistant] section from the configuration. """Process the [homeassistant] section from the configuration.
This method is a coroutine. This method is a coroutine.
@ -405,12 +411,12 @@ async def async_process_ha_core_config(hass, config):
# Only load auth during startup. # Only load auth during startup.
if not hasattr(hass, 'auth'): if not hasattr(hass, 'auth'):
hass.auth = await auth.auth_manager_from_config( setattr(hass, 'auth', await auth.auth_manager_from_config(
hass, config.get(CONF_AUTH_PROVIDERS, [])) hass, config.get(CONF_AUTH_PROVIDERS, [])))
hac = hass.config hac = hass.config
def set_time_zone(time_zone_str): def set_time_zone(time_zone_str: Optional[str]) -> None:
"""Help to set the time zone.""" """Help to set the time zone."""
if time_zone_str is None: if time_zone_str is None:
return return
@ -430,11 +436,10 @@ async def async_process_ha_core_config(hass, config):
if key in config: if key in config:
setattr(hac, attr, config[key]) setattr(hac, attr, config[key])
if CONF_TIME_ZONE in config:
set_time_zone(config.get(CONF_TIME_ZONE)) set_time_zone(config.get(CONF_TIME_ZONE))
# Init whitelist external dir # Init whitelist external dir
hac.whitelist_external_dirs = set((hass.config.path('www'),)) hac.whitelist_external_dirs = {hass.config.path('www')}
if CONF_WHITELIST_EXTERNAL_DIRS in config: if CONF_WHITELIST_EXTERNAL_DIRS in config:
hac.whitelist_external_dirs.update( hac.whitelist_external_dirs.update(
set(config[CONF_WHITELIST_EXTERNAL_DIRS])) set(config[CONF_WHITELIST_EXTERNAL_DIRS]))
@ -484,12 +489,12 @@ async def async_process_ha_core_config(hass, config):
hac.time_zone, hac.elevation): hac.time_zone, hac.elevation):
return return
discovered = [] discovered = [] # type: List[Tuple[str, Any]]
# If we miss some of the needed values, auto detect them # If we miss some of the needed values, auto detect them
if None in (hac.latitude, hac.longitude, hac.units, if None in (hac.latitude, hac.longitude, hac.units,
hac.time_zone): hac.time_zone):
info = await hass.async_add_job( info = await hass.async_add_executor_job(
loc_util.detect_location_info) loc_util.detect_location_info)
if info is None: if info is None:
@ -515,7 +520,7 @@ async def async_process_ha_core_config(hass, config):
if hac.elevation is None and hac.latitude is not None and \ if hac.elevation is None and hac.latitude is not None and \
hac.longitude is not None: hac.longitude is not None:
elevation = await hass.async_add_job( elevation = await hass.async_add_executor_job(
loc_util.elevation, hac.latitude, hac.longitude) loc_util.elevation, hac.latitude, hac.longitude)
hac.elevation = elevation hac.elevation = elevation
discovered.append(('elevation', elevation)) discovered.append(('elevation', elevation))
@ -526,7 +531,8 @@ async def async_process_ha_core_config(hass, config):
", ".join('{}: {}'.format(key, val) for key, val in discovered)) ", ".join('{}: {}'.format(key, val) for key, val in discovered))
def _log_pkg_error(package, component, config, message): def _log_pkg_error(
package: str, component: str, config: Dict, message: str) -> None:
"""Log an error while merging packages.""" """Log an error while merging packages."""
message = "Package {} setup failed. Component {} {}".format( message = "Package {} setup failed. Component {} {}".format(
package, component, message) package, component, message)
@ -539,12 +545,13 @@ def _log_pkg_error(package, component, config, message):
_LOGGER.error(message) _LOGGER.error(message)
def _identify_config_schema(module): def _identify_config_schema(module: ModuleType) -> \
Tuple[Optional[str], Optional[Dict]]:
"""Extract the schema and identify list or dict based.""" """Extract the schema and identify list or dict based."""
try: try:
schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] # type: ignore
except (AttributeError, KeyError): except (AttributeError, KeyError):
return (None, None) return None, None
t_schema = str(schema) t_schema = str(schema)
if t_schema.startswith('{'): if t_schema.startswith('{'):
return ('dict', schema) return ('dict', schema)
@ -553,9 +560,10 @@ def _identify_config_schema(module):
return '', schema return '', schema
def _recursive_merge(conf, package): def _recursive_merge(
conf: Dict[str, Any], package: Dict[str, Any]) -> Union[bool, str]:
"""Merge package into conf, recursively.""" """Merge package into conf, recursively."""
error = False error = False # type: Union[bool, str]
for key, pack_conf in package.items(): for key, pack_conf in package.items():
if isinstance(pack_conf, dict): if isinstance(pack_conf, dict):
if not pack_conf: if not pack_conf:
@ -576,8 +584,8 @@ def _recursive_merge(conf, package):
return error return error
def merge_packages_config(hass, config, packages, def merge_packages_config(hass: HomeAssistant, config: Dict, packages: Dict,
_log_pkg_error=_log_pkg_error): _log_pkg_error: Callable = _log_pkg_error) -> Dict:
"""Merge packages into the top-level configuration. Mutate config.""" """Merge packages into the top-level configuration. Mutate config."""
# pylint: disable=too-many-nested-blocks # pylint: disable=too-many-nested-blocks
PACKAGES_CONFIG_SCHEMA(packages) PACKAGES_CONFIG_SCHEMA(packages)
@ -641,7 +649,8 @@ def merge_packages_config(hass, config, packages,
@callback @callback
def async_process_component_config(hass, config, domain): def async_process_component_config(
hass: HomeAssistant, config: Dict, domain: str) -> Optional[Dict]:
"""Check component configuration and return processed configuration. """Check component configuration and return processed configuration.
Returns None on error. Returns None on error.
@ -703,14 +712,14 @@ def async_process_component_config(hass, config, domain):
return config return config
async def async_check_ha_config_file(hass): async def async_check_ha_config_file(hass: HomeAssistant) -> Optional[str]:
"""Check if Home Assistant configuration file is valid. """Check if Home Assistant configuration file is valid.
This method is a coroutine. This method is a coroutine.
""" """
from homeassistant.scripts.check_config import check_ha_config_file from homeassistant.scripts.check_config import check_ha_config_file
res = await hass.async_add_job( res = await hass.async_add_executor_job(
check_ha_config_file, hass) check_ha_config_file, hass)
if not res.errors: if not res.errors:
@ -719,7 +728,9 @@ async def async_check_ha_config_file(hass):
@callback @callback
def async_notify_setup_error(hass, component, display_link=False): def async_notify_setup_error(
hass: HomeAssistant, component: str,
display_link: bool = False) -> None:
"""Print a persistent notification. """Print a persistent notification.
This method must be run in the event loop. This method must be run in the event loop.

View file

@ -113,10 +113,10 @@ the flow from the config panel.
import logging import logging
import uuid import uuid
from typing import Set # noqa pylint: disable=unused-import from typing import Set, Optional # noqa pylint: disable=unused-import
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.core import callback from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component, async_process_deps_reqs from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
@ -164,8 +164,9 @@ class ConfigEntry:
__slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source', __slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source',
'state') 'state')
def __init__(self, version, domain, title, data, source, entry_id=None, def __init__(self, version: str, domain: str, title: str, data: dict,
state=ENTRY_STATE_NOT_LOADED): source: str, entry_id: Optional[str] = None,
state: str = ENTRY_STATE_NOT_LOADED) -> None:
"""Initialize a config entry.""" """Initialize a config entry."""
# Unique id of the config entry # Unique id of the config entry
self.entry_id = entry_id or uuid.uuid4().hex self.entry_id = entry_id or uuid.uuid4().hex
@ -188,7 +189,8 @@ class ConfigEntry:
# State of the entry (LOADED, NOT_LOADED) # State of the entry (LOADED, NOT_LOADED)
self.state = state self.state = state
async def async_setup(self, hass, *, component=None): async def async_setup(
self, hass: HomeAssistant, *, component=None) -> None:
"""Set up an entry.""" """Set up an entry."""
if component is None: if component is None:
component = getattr(hass.components, self.domain) component = getattr(hass.components, self.domain)

View file

@ -4,9 +4,9 @@ Core components of Home Assistant.
Home Assistant is a Home Automation framework for observing the state Home Assistant is a Home Automation framework for observing the state
of entities and react to changes. of entities and react to changes.
""" """
# pylint: disable=unused-import
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import datetime
import enum import enum
import logging import logging
import os import os
@ -17,9 +17,10 @@ import threading
from time import monotonic from time import monotonic
from types import MappingProxyType from types import MappingProxyType
# pylint: disable=unused-import
from typing import ( # NOQA from typing import ( # NOQA
Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set, Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set,
TYPE_CHECKING) TYPE_CHECKING, Awaitable, Iterator)
from async_timeout import timeout from async_timeout import timeout
import voluptuous as vol import voluptuous as vol
@ -44,11 +45,13 @@ from homeassistant.util import location
from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA
# Typing imports that create a circular dependency # Typing imports that create a circular dependency
# pylint: disable=using-constant-test,unused-import # pylint: disable=using-constant-test
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntries # noqa from homeassistant.config_entries import ConfigEntries # noqa
T = TypeVar('T') T = TypeVar('T')
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
CALLBACK_TYPE = Callable[[], None]
DOMAIN = 'homeassistant' DOMAIN = 'homeassistant'
@ -79,7 +82,7 @@ def valid_state(state: str) -> bool:
return len(state) < 256 return len(state) < 256
def callback(func: Callable[..., T]) -> Callable[..., T]: def callback(func: CALLABLE_T) -> CALLABLE_T:
"""Annotation to mark method as safe to call from within the event loop.""" """Annotation to mark method as safe to call from within the event loop."""
setattr(func, '_hass_callback', True) setattr(func, '_hass_callback', True)
return func return func
@ -91,7 +94,7 @@ def is_callback(func: Callable[..., Any]) -> bool:
@callback @callback
def async_loop_exception_handler(loop, context): def async_loop_exception_handler(_: Any, context: Dict) -> None:
"""Handle all exception inside the core loop.""" """Handle all exception inside the core loop."""
kwargs = {} kwargs = {}
exception = context.get('exception') exception = context.get('exception')
@ -119,7 +122,9 @@ class CoreState(enum.Enum):
class HomeAssistant: class HomeAssistant:
"""Root object of the Home Assistant home automation.""" """Root object of the Home Assistant home automation."""
def __init__(self, loop=None): def __init__(
self,
loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None:
"""Initialize new Home Assistant object.""" """Initialize new Home Assistant object."""
if sys.platform == 'win32': if sys.platform == 'win32':
self.loop = loop or asyncio.ProactorEventLoop() self.loop = loop or asyncio.ProactorEventLoop()
@ -170,7 +175,7 @@ class HomeAssistant:
self.loop.close() self.loop.close()
return self.exit_code return self.exit_code
async def async_start(self): async def async_start(self) -> None:
"""Finalize startup from inside the event loop. """Finalize startup from inside the event loop.
This method is a coroutine. This method is a coroutine.
@ -178,8 +183,7 @@ class HomeAssistant:
_LOGGER.info("Starting Home Assistant") _LOGGER.info("Starting Home Assistant")
self.state = CoreState.starting self.state = CoreState.starting
# pylint: disable=protected-access setattr(self.loop, '_thread_ident', threading.get_ident())
self.loop._thread_ident = threading.get_ident()
self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.bus.async_fire(EVENT_HOMEASSISTANT_START)
try: try:
@ -230,7 +234,8 @@ class HomeAssistant:
elif asyncio.iscoroutinefunction(target): elif asyncio.iscoroutinefunction(target):
task = self.loop.create_task(target(*args)) task = self.loop.create_task(target(*args))
else: else:
task = self.loop.run_in_executor(None, target, *args) task = self.loop.run_in_executor( # type: ignore
None, target, *args)
# If a task is scheduled # If a task is scheduled
if self._track_task and task is not None: if self._track_task and task is not None:
@ -256,11 +261,11 @@ class HomeAssistant:
@callback @callback
def async_add_executor_job( def async_add_executor_job(
self, self,
target: Callable[..., Any], target: Callable[..., T],
*args: Any) -> asyncio.Future: *args: Any) -> Awaitable[T]:
"""Add an executor job from within the event loop.""" """Add an executor job from within the event loop."""
task = self.loop.run_in_executor( # type: ignore task = self.loop.run_in_executor(
None, target, *args) # type: asyncio.Future None, target, *args)
# If a task is scheduled # If a task is scheduled
if self._track_task: if self._track_task:
@ -269,12 +274,12 @@ class HomeAssistant:
return task return task
@callback @callback
def async_track_tasks(self): def async_track_tasks(self) -> None:
"""Track tasks so you can wait for all tasks to be done.""" """Track tasks so you can wait for all tasks to be done."""
self._track_task = True self._track_task = True
@callback @callback
def async_stop_track_tasks(self): def async_stop_track_tasks(self) -> None:
"""Stop track tasks so you can't wait for all tasks to be done.""" """Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False self._track_task = False
@ -297,7 +302,7 @@ class HomeAssistant:
run_coroutine_threadsafe( run_coroutine_threadsafe(
self.async_block_till_done(), loop=self.loop).result() self.async_block_till_done(), loop=self.loop).result()
async def async_block_till_done(self): async def async_block_till_done(self) -> None:
"""Block till all pending work is done.""" """Block till all pending work is done."""
# To flush out any call_soon_threadsafe # To flush out any call_soon_threadsafe
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
@ -342,9 +347,9 @@ class EventOrigin(enum.Enum):
local = 'LOCAL' local = 'LOCAL'
remote = 'REMOTE' remote = 'REMOTE'
def __str__(self): def __str__(self) -> str:
"""Return the event.""" """Return the event."""
return self.value return self.value # type: ignore
class Event: class Event:
@ -352,15 +357,16 @@ class Event:
__slots__ = ['event_type', 'data', 'origin', 'time_fired'] __slots__ = ['event_type', 'data', 'origin', 'time_fired']
def __init__(self, event_type, data=None, origin=EventOrigin.local, def __init__(self, event_type: str, data: Optional[Dict] = None,
time_fired=None): origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None) -> None:
"""Initialize a new event.""" """Initialize a new event."""
self.event_type = event_type self.event_type = event_type
self.data = data or {} self.data = data or {}
self.origin = origin self.origin = origin
self.time_fired = time_fired or dt_util.utcnow() self.time_fired = time_fired or dt_util.utcnow()
def as_dict(self): def as_dict(self) -> Dict:
"""Create a dict representation of this Event. """Create a dict representation of this Event.
Async friendly. Async friendly.
@ -372,7 +378,7 @@ class Event:
'time_fired': self.time_fired, 'time_fired': self.time_fired,
} }
def __repr__(self): def __repr__(self) -> str:
"""Return the representation.""" """Return the representation."""
# pylint: disable=maybe-no-member # pylint: disable=maybe-no-member
if self.data: if self.data:
@ -383,9 +389,9 @@ class Event:
return "<Event {}[{}]>".format(self.event_type, return "<Event {}[{}]>".format(self.event_type,
str(self.origin)[0]) str(self.origin)[0])
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Return the comparison.""" """Return the comparison."""
return (self.__class__ == other.__class__ and return (self.__class__ == other.__class__ and # type: ignore
self.event_type == other.event_type and self.event_type == other.event_type and
self.data == other.data and self.data == other.data and
self.origin == other.origin and self.origin == other.origin and
@ -401,7 +407,7 @@ class EventBus:
self._hass = hass self._hass = hass
@callback @callback
def async_listeners(self): def async_listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners. """Return dictionary with events and the number of listeners.
This method must be run in the event loop. This method must be run in the event loop.
@ -410,20 +416,21 @@ class EventBus:
for key in self._listeners} for key in self._listeners}
@property @property
def listeners(self): def listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners.""" """Return dictionary with events and the number of listeners."""
return run_callback_threadsafe( return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_listeners self._hass.loop, self.async_listeners
).result() ).result()
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local): def fire(self, event_type: str, event_data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local) -> None:
"""Fire an event.""" """Fire an event."""
self._hass.loop.call_soon_threadsafe( self._hass.loop.call_soon_threadsafe(
self.async_fire, event_type, event_data, origin) self.async_fire, event_type, event_data, origin)
@callback @callback
def async_fire(self, event_type: str, event_data=None, def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
origin=EventOrigin.local): origin: EventOrigin = EventOrigin.local) -> None:
"""Fire an event. """Fire an event.
This method must be run in the event loop. This method must be run in the event loop.
@ -447,7 +454,8 @@ class EventBus:
for func in listeners: for func in listeners:
self._hass.async_add_job(func, event) self._hass.async_add_job(func, event)
def listen(self, event_type, listener): def listen(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
@ -456,7 +464,7 @@ class EventBus:
async_remove_listener = run_callback_threadsafe( async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen, event_type, listener).result() self._hass.loop, self.async_listen, event_type, listener).result()
def remove_listener(): def remove_listener() -> None:
"""Remove the listener.""" """Remove the listener."""
run_callback_threadsafe( run_callback_threadsafe(
self._hass.loop, async_remove_listener).result() self._hass.loop, async_remove_listener).result()
@ -464,7 +472,8 @@ class EventBus:
return remove_listener return remove_listener
@callback @callback
def async_listen(self, event_type, listener): def async_listen(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
@ -477,13 +486,14 @@ class EventBus:
else: else:
self._listeners[event_type] = [listener] self._listeners[event_type] = [listener]
def remove_listener(): def remove_listener() -> None:
"""Remove the listener.""" """Remove the listener."""
self._async_remove_listener(event_type, listener) self._async_remove_listener(event_type, listener)
return remove_listener return remove_listener
def listen_once(self, event_type, listener): def listen_once(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen once for event of a specific type. """Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
@ -495,7 +505,7 @@ class EventBus:
self._hass.loop, self.async_listen_once, event_type, listener, self._hass.loop, self.async_listen_once, event_type, listener,
).result() ).result()
def remove_listener(): def remove_listener() -> None:
"""Remove the listener.""" """Remove the listener."""
run_callback_threadsafe( run_callback_threadsafe(
self._hass.loop, async_remove_listener).result() self._hass.loop, async_remove_listener).result()
@ -503,7 +513,8 @@ class EventBus:
return remove_listener return remove_listener
@callback @callback
def async_listen_once(self, event_type, listener): def async_listen_once(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen once for event of a specific type. """Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
@ -514,7 +525,7 @@ class EventBus:
This method must be run in the event loop. This method must be run in the event loop.
""" """
@callback @callback
def onetime_listener(event): def onetime_listener(event: Event) -> None:
"""Remove listener from event bus and then fire listener.""" """Remove listener from event bus and then fire listener."""
if hasattr(onetime_listener, 'run'): if hasattr(onetime_listener, 'run'):
return return
@ -530,7 +541,8 @@ class EventBus:
return self.async_listen(event_type, onetime_listener) return self.async_listen(event_type, onetime_listener)
@callback @callback
def _async_remove_listener(self, event_type, listener): def _async_remove_listener(
self, event_type: str, listener: Callable) -> None:
"""Remove a listener of a specific event_type. """Remove a listener of a specific event_type.
This method must be run in the event loop. This method must be run in the event loop.
@ -560,8 +572,10 @@ class State:
__slots__ = ['entity_id', 'state', 'attributes', __slots__ = ['entity_id', 'state', 'attributes',
'last_changed', 'last_updated'] 'last_changed', 'last_updated']
def __init__(self, entity_id, state, attributes=None, last_changed=None, def __init__(self, entity_id: str, state: Any,
last_updated=None): attributes: Optional[Dict] = None,
last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None) -> None:
"""Initialize a new state.""" """Initialize a new state."""
state = str(state) state = str(state)
@ -582,23 +596,23 @@ class State:
self.last_changed = last_changed or self.last_updated self.last_changed = last_changed or self.last_updated
@property @property
def domain(self): def domain(self) -> str:
"""Domain of this state.""" """Domain of this state."""
return split_entity_id(self.entity_id)[0] return split_entity_id(self.entity_id)[0]
@property @property
def object_id(self): def object_id(self) -> str:
"""Object id of this state.""" """Object id of this state."""
return split_entity_id(self.entity_id)[1] return split_entity_id(self.entity_id)[1]
@property @property
def name(self): def name(self) -> str:
"""Name of this state.""" """Name of this state."""
return ( return (
self.attributes.get(ATTR_FRIENDLY_NAME) or self.attributes.get(ATTR_FRIENDLY_NAME) or
self.object_id.replace('_', ' ')) self.object_id.replace('_', ' '))
def as_dict(self): def as_dict(self) -> Dict:
"""Return a dict representation of the State. """Return a dict representation of the State.
Async friendly. Async friendly.
@ -613,7 +627,7 @@ class State:
'last_updated': self.last_updated} 'last_updated': self.last_updated}
@classmethod @classmethod
def from_dict(cls, json_dict): def from_dict(cls, json_dict: Dict) -> Any:
"""Initialize a state from a dict. """Initialize a state from a dict.
Async friendly. Async friendly.
@ -637,14 +651,14 @@ class State:
return cls(json_dict['entity_id'], json_dict['state'], return cls(json_dict['entity_id'], json_dict['state'],
json_dict.get('attributes'), last_changed, last_updated) json_dict.get('attributes'), last_changed, last_updated)
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Return the comparison of the state.""" """Return the comparison of the state."""
return (self.__class__ == other.__class__ and return (self.__class__ == other.__class__ and # type: ignore
self.entity_id == other.entity_id and self.entity_id == other.entity_id and
self.state == other.state and self.state == other.state and
self.attributes == other.attributes) self.attributes == other.attributes)
def __repr__(self): def __repr__(self) -> str:
"""Return the representation of the states.""" """Return the representation of the states."""
attr = "; {}".format(util.repr_helper(self.attributes)) \ attr = "; {}".format(util.repr_helper(self.attributes)) \
if self.attributes else "" if self.attributes else ""
@ -657,21 +671,23 @@ class State:
class StateMachine: class StateMachine:
"""Helper class that tracks the state of different entities.""" """Helper class that tracks the state of different entities."""
def __init__(self, bus, loop): def __init__(self, bus: EventBus,
loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine.""" """Initialize state machine."""
self._states = {} # type: Dict[str, State] self._states = {} # type: Dict[str, State]
self._bus = bus self._bus = bus
self._loop = loop self._loop = loop
def entity_ids(self, domain_filter=None): def entity_ids(self, domain_filter: Optional[str] = None)-> List[str]:
"""List of entity ids that are being tracked.""" """List of entity ids that are being tracked."""
future = run_callback_threadsafe( future = run_callback_threadsafe(
self._loop, self.async_entity_ids, domain_filter self._loop, self.async_entity_ids, domain_filter
) )
return future.result() return future.result() # type: ignore
@callback @callback
def async_entity_ids(self, domain_filter=None): def async_entity_ids(
self, domain_filter: Optional[str] = None) -> List[str]:
"""List of entity ids that are being tracked. """List of entity ids that are being tracked.
This method must be run in the event loop. This method must be run in the event loop.
@ -684,26 +700,27 @@ class StateMachine:
return [state.entity_id for state in self._states.values() return [state.entity_id for state in self._states.values()
if state.domain == domain_filter] if state.domain == domain_filter]
def all(self): def all(self)-> List[State]:
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result() return run_callback_threadsafe( # type: ignore
self._loop, self.async_all).result()
@callback @callback
def async_all(self): def async_all(self)-> List[State]:
"""Create a list of all states. """Create a list of all states.
This method must be run in the event loop. This method must be run in the event loop.
""" """
return list(self._states.values()) return list(self._states.values())
def get(self, entity_id): def get(self, entity_id: str) -> Optional[State]:
"""Retrieve state of entity_id or None if not found. """Retrieve state of entity_id or None if not found.
Async friendly. Async friendly.
""" """
return self._states.get(entity_id.lower()) return self._states.get(entity_id.lower())
def is_state(self, entity_id, state): def is_state(self, entity_id: str, state: State) -> bool:
"""Test if entity exists and is specified state. """Test if entity exists and is specified state.
Async friendly. Async friendly.
@ -711,16 +728,16 @@ class StateMachine:
state_obj = self.get(entity_id) state_obj = self.get(entity_id)
return state_obj is not None and state_obj.state == state return state_obj is not None and state_obj.state == state
def remove(self, entity_id): def remove(self, entity_id: str) -> bool:
"""Remove the state of an entity. """Remove the state of an entity.
Returns boolean to indicate if an entity was removed. Returns boolean to indicate if an entity was removed.
""" """
return run_callback_threadsafe( return run_callback_threadsafe( # type: ignore
self._loop, self.async_remove, entity_id).result() self._loop, self.async_remove, entity_id).result()
@callback @callback
def async_remove(self, entity_id): def async_remove(self, entity_id: str) -> bool:
"""Remove the state of an entity. """Remove the state of an entity.
Returns boolean to indicate if an entity was removed. Returns boolean to indicate if an entity was removed.
@ -740,7 +757,9 @@ class StateMachine:
}) })
return True return True
def set(self, entity_id, new_state, attributes=None, force_update=False): def set(self, entity_id: str, new_state: Any,
attributes: Optional[Dict] = None,
force_update: bool = False) -> None:
"""Set the state of an entity, add entity if it does not exist. """Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state. Attributes is an optional dict to specify attributes of this state.
@ -754,8 +773,9 @@ class StateMachine:
).result() ).result()
@callback @callback
def async_set(self, entity_id, new_state, attributes=None, def async_set(self, entity_id: str, new_state: Any,
force_update=False): attributes: Optional[Dict] = None,
force_update: bool = False) -> None:
"""Set the state of an entity, add entity if it does not exist. """Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state. Attributes is an optional dict to specify attributes of this state.
@ -769,15 +789,19 @@ class StateMachine:
new_state = str(new_state) new_state = str(new_state)
attributes = attributes or {} attributes = attributes or {}
old_state = self._states.get(entity_id) old_state = self._states.get(entity_id)
is_existing = old_state is not None if old_state is None:
same_state = (is_existing and old_state.state == new_state and same_state = False
same_attr = False
last_changed = None
else:
same_state = (old_state.state == new_state and
not force_update) not force_update)
same_attr = is_existing and old_state.attributes == attributes same_attr = old_state.attributes == attributes
last_changed = old_state.last_changed if same_state else None
if same_state and same_attr: if same_state and same_attr:
return return
last_changed = old_state.last_changed if same_state else None
state = State(entity_id, new_state, attributes, last_changed) state = State(entity_id, new_state, attributes, last_changed)
self._states[entity_id] = state self._states[entity_id] = state
self._bus.async_fire(EVENT_STATE_CHANGED, { self._bus.async_fire(EVENT_STATE_CHANGED, {
@ -792,7 +816,7 @@ class Service:
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction'] __slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
def __init__(self, func, schema): def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None:
"""Initialize a service.""" """Initialize a service."""
self.func = func self.func = func
self.schema = schema self.schema = schema
@ -805,14 +829,15 @@ class ServiceCall:
__slots__ = ['domain', 'service', 'data', 'call_id'] __slots__ = ['domain', 'service', 'data', 'call_id']
def __init__(self, domain, service, data=None, call_id=None): def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
call_id: Optional[str] = None) -> None:
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain.lower()
self.service = service.lower() self.service = service.lower()
self.data = MappingProxyType(data or {}) self.data = MappingProxyType(data or {})
self.call_id = call_id self.call_id = call_id
def __repr__(self): def __repr__(self) -> str:
"""Return the representation of the service.""" """Return the representation of the service."""
if self.data: if self.data:
return "<ServiceCall {}.{}: {}>".format( return "<ServiceCall {}.{}: {}>".format(
@ -824,13 +849,13 @@ class ServiceCall:
class ServiceRegistry: class ServiceRegistry:
"""Offer the services over the eventbus.""" """Offer the services over the eventbus."""
def __init__(self, hass): def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a service registry.""" """Initialize a service registry."""
self._services = {} # type: Dict[str, Dict[str, Service]] self._services = {} # type: Dict[str, Dict[str, Service]]
self._hass = hass self._hass = hass
self._async_unsub_call_event = None self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
def _gen_unique_id(): def _gen_unique_id() -> Iterator[str]:
cur_id = 1 cur_id = 1
while True: while True:
yield '{}-{}'.format(id(self), cur_id) yield '{}-{}'.format(id(self), cur_id)
@ -840,14 +865,14 @@ class ServiceRegistry:
self._generate_unique_id = lambda: next(gen) self._generate_unique_id = lambda: next(gen)
@property @property
def services(self): def services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services.""" """Return dictionary with per domain a list of available services."""
return run_callback_threadsafe( return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_services, self._hass.loop, self.async_services,
).result() ).result()
@callback @callback
def async_services(self): def async_services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services. """Return dictionary with per domain a list of available services.
This method must be run in the event loop. This method must be run in the event loop.
@ -855,14 +880,15 @@ class ServiceRegistry:
return {domain: self._services[domain].copy() return {domain: self._services[domain].copy()
for domain in self._services} for domain in self._services}
def has_service(self, domain, service): def has_service(self, domain: str, service: str) -> bool:
"""Test if specified service exists. """Test if specified service exists.
Async friendly. Async friendly.
""" """
return service.lower() in self._services.get(domain.lower(), []) return service.lower() in self._services.get(domain.lower(), [])
def register(self, domain, service, service_func, schema=None): def register(self, domain: str, service: str, service_func: Callable,
schema: Optional[vol.Schema] = None) -> None:
""" """
Register a service. Register a service.
@ -874,7 +900,8 @@ class ServiceRegistry:
).result() ).result()
@callback @callback
def async_register(self, domain, service, service_func, schema=None): def async_register(self, domain: str, service: str, service_func: Callable,
schema: Optional[vol.Schema] = None) -> None:
""" """
Register a service. Register a service.
@ -900,13 +927,13 @@ class ServiceRegistry:
{ATTR_DOMAIN: domain, ATTR_SERVICE: service} {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
) )
def remove(self, domain, service): def remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.""" """Remove a registered service from service handler."""
run_callback_threadsafe( run_callback_threadsafe(
self._hass.loop, self.async_remove, domain, service).result() self._hass.loop, self.async_remove, domain, service).result()
@callback @callback
def async_remove(self, domain, service): def async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler. """Remove a registered service from service handler.
This method must be run in the event loop. This method must be run in the event loop.
@ -926,7 +953,9 @@ class ServiceRegistry:
{ATTR_DOMAIN: domain, ATTR_SERVICE: service} {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
) )
def call(self, domain, service, service_data=None, blocking=False): def call(self, domain: str, service: str,
service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]:
""" """
Call a service. Call a service.
@ -943,13 +972,14 @@ class ServiceRegistry:
Because the service is sent as an event you are not allowed to use Because the service is sent as an event you are not allowed to use
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
""" """
return run_coroutine_threadsafe( return run_coroutine_threadsafe( # type: ignore
self.async_call(domain, service, service_data, blocking), self.async_call(domain, service, service_data, blocking),
self._hass.loop self._hass.loop
).result() ).result()
async def async_call(self, domain, service, service_data=None, async def async_call(self, domain: str, service: str,
blocking=False): service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]:
""" """
Call a service. Call a service.
@ -981,7 +1011,7 @@ class ServiceRegistry:
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
@callback @callback
def service_executed(event): def service_executed(event: Event) -> None:
"""Handle an executed service.""" """Handle an executed service."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id: if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True) fut.set_result(True)
@ -991,18 +1021,20 @@ class ServiceRegistry:
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
if blocking:
done, _ = await asyncio.wait( done, _ = await asyncio.wait(
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT) [fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
success = bool(done) success = bool(done)
unsub() unsub()
return success return success
async def _event_to_service_call(self, event): self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
return None
async def _event_to_service_call(self, event: Event) -> None:
"""Handle the SERVICE_CALLED events from the EventBus.""" """Handle the SERVICE_CALLED events from the EventBus."""
service_data = event.data.get(ATTR_SERVICE_DATA) or {} service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower() domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
service = event.data.get(ATTR_SERVICE).lower() service = event.data.get(ATTR_SERVICE).lower() # type: ignore
call_id = event.data.get(ATTR_SERVICE_CALL_ID) call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service): if not self.has_service(domain, service):
@ -1013,7 +1045,7 @@ class ServiceRegistry:
service_handler = self._services[domain][service] service_handler = self._services[domain][service]
def fire_service_executed(): def fire_service_executed() -> None:
"""Fire service executed event.""" """Fire service executed event."""
if not call_id: if not call_id:
return return
@ -1045,12 +1077,12 @@ class ServiceRegistry:
await service_handler.func(service_call) await service_handler.func(service_call)
fire_service_executed() fire_service_executed()
else: else:
def execute_service(): def execute_service() -> None:
"""Execute a service and fires a SERVICE_EXECUTED event.""" """Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call) service_handler.func(service_call)
fire_service_executed() fire_service_executed()
await self._hass.async_add_job(execute_service) await self._hass.async_add_executor_job(execute_service)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error executing service %s', service_call) _LOGGER.exception('Error executing service %s', service_call)
@ -1058,13 +1090,13 @@ class ServiceRegistry:
class Config: class Config:
"""Configuration settings for Home Assistant.""" """Configuration settings for Home Assistant."""
def __init__(self): def __init__(self) -> None:
"""Initialize a new config object.""" """Initialize a new config object."""
self.latitude = None # type: Optional[float] self.latitude = None # type: Optional[float]
self.longitude = None # type: Optional[float] self.longitude = None # type: Optional[float]
self.elevation = None # type: Optional[int] self.elevation = None # type: Optional[int]
self.location_name = None # type: Optional[str] self.location_name = None # type: Optional[str]
self.time_zone = None # type: Optional[str] self.time_zone = None # type: Optional[datetime.tzinfo]
self.units = METRIC_SYSTEM # type: UnitSystem self.units = METRIC_SYSTEM # type: UnitSystem
# If True, pip install is skipped for requirements on startup # If True, pip install is skipped for requirements on startup
@ -1090,7 +1122,7 @@ class Config:
return self.units.length( return self.units.length(
location.distance(self.latitude, self.longitude, lat, lon), 'm') location.distance(self.latitude, self.longitude, lat, lon), 'm')
def path(self, *path): def path(self, *path: str) -> str:
"""Generate path to the file within the configuration directory. """Generate path to the file within the configuration directory.
Async friendly. Async friendly.
@ -1122,12 +1154,14 @@ class Config:
return False return False
def as_dict(self): def as_dict(self) -> Dict:
"""Create a dictionary representation of this dict. """Create a dictionary representation of this dict.
Async friendly. Async friendly.
""" """
time_zone = self.time_zone or dt_util.UTC time_zone = dt_util.UTC.zone
if self.time_zone and getattr(self.time_zone, 'zone'):
time_zone = getattr(self.time_zone, 'zone')
return { return {
'latitude': self.latitude, 'latitude': self.latitude,
@ -1135,7 +1169,7 @@ class Config:
'elevation': self.elevation, 'elevation': self.elevation,
'unit_system': self.units.as_dict(), 'unit_system': self.units.as_dict(),
'location_name': self.location_name, 'location_name': self.location_name,
'time_zone': time_zone.zone, 'time_zone': time_zone,
'components': self.components, 'components': self.components,
'config_dir': self.config_dir, 'config_dir': self.config_dir,
'whitelist_external_dirs': self.whitelist_external_dirs, 'whitelist_external_dirs': self.whitelist_external_dirs,
@ -1143,12 +1177,12 @@ class Config:
} }
def _async_create_timer(hass): def _async_create_timer(hass: HomeAssistant) -> None:
"""Create a timer that will start on HOMEASSISTANT_START.""" """Create a timer that will start on HOMEASSISTANT_START."""
handle = None handle = None
@callback @callback
def fire_time_event(nxt): def fire_time_event(nxt: float) -> None:
"""Fire next time event.""" """Fire next time event."""
nonlocal handle nonlocal handle
@ -1165,7 +1199,7 @@ def _async_create_timer(hass):
handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt) handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt)
@callback @callback
def stop_timer(event): def stop_timer(_: Event) -> None:
"""Stop the timer.""" """Stop the timer."""
if handle is not None: if handle is not None:
handle.cancel() handle.cancel()

View file

@ -1,8 +1,9 @@
"""Classes to help gather user submissions.""" """Classes to help gather user submissions."""
import logging import logging
import uuid import uuid
from typing import Dict, Any # noqa pylint: disable=unused-import import voluptuous as vol
from .core import callback from typing import Dict, Any, Callable, List, Optional # noqa pylint: disable=unused-import
from .core import callback, HomeAssistant
from .exceptions import HomeAssistantError from .exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -35,7 +36,8 @@ class UnknownStep(FlowError):
class FlowManager: class FlowManager:
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
def __init__(self, hass, async_create_flow, async_finish_flow): def __init__(self, hass: HomeAssistant, async_create_flow: Callable,
async_finish_flow: Callable) -> None:
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._progress = {} # type: Dict[str, Any] self._progress = {} # type: Dict[str, Any]
@ -43,7 +45,7 @@ class FlowManager:
self._async_finish_flow = async_finish_flow self._async_finish_flow = async_finish_flow
@callback @callback
def async_progress(self): def async_progress(self) -> List[Dict]:
"""Return the flows in progress.""" """Return the flows in progress."""
return [{ return [{
'flow_id': flow.flow_id, 'flow_id': flow.flow_id,
@ -51,7 +53,8 @@ class FlowManager:
'source': flow.source, 'source': flow.source,
} for flow in self._progress.values()] } for flow in self._progress.values()]
async def async_init(self, handler, *, source=SOURCE_USER, data=None): async def async_init(self, handler: Callable, *, source: str = SOURCE_USER,
data: str = None) -> Any:
"""Start a configuration flow.""" """Start a configuration flow."""
flow = await self._async_create_flow(handler, source=source, data=data) flow = await self._async_create_flow(handler, source=source, data=data)
flow.hass = self.hass flow.hass = self.hass
@ -67,7 +70,8 @@ class FlowManager:
return await self._async_handle_step(flow, step, data) return await self._async_handle_step(flow, step, data)
async def async_configure(self, flow_id, user_input=None): async def async_configure(
self, flow_id: str, user_input: str = None) -> Any:
"""Continue a configuration flow.""" """Continue a configuration flow."""
flow = self._progress.get(flow_id) flow = self._progress.get(flow_id)
@ -83,12 +87,13 @@ class FlowManager:
flow, step_id, user_input) flow, step_id, user_input)
@callback @callback
def async_abort(self, flow_id): def async_abort(self, flow_id: str) -> None:
"""Abort a flow.""" """Abort a flow."""
if self._progress.pop(flow_id, None) is None: if self._progress.pop(flow_id, None) is None:
raise UnknownFlow raise UnknownFlow
async def _async_handle_step(self, flow, step_id, user_input): async def _async_handle_step(self, flow: Any, step_id: str,
user_input: Optional[str]) -> Dict:
"""Handle a step of a flow.""" """Handle a step of a flow."""
method = "async_step_{}".format(step_id) method = "async_step_{}".format(step_id)
@ -97,7 +102,7 @@ class FlowManager:
raise UnknownStep("Handler {} doesn't support step {}".format( raise UnknownStep("Handler {} doesn't support step {}".format(
flow.__class__.__name__, step_id)) flow.__class__.__name__, step_id))
result = await getattr(flow, method)(user_input) result = await getattr(flow, method)(user_input) # type: Dict
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY, if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
RESULT_TYPE_ABORT): RESULT_TYPE_ABORT):
@ -133,8 +138,9 @@ class FlowHandler:
VERSION = 1 VERSION = 1
@callback @callback
def async_show_form(self, *, step_id, data_schema=None, errors=None, def async_show_form(self, *, step_id: str, data_schema: vol.Schema = None,
description_placeholders=None): errors: Dict = None,
description_placeholders: Dict = None) -> Dict:
"""Return the definition of a form to gather user input.""" """Return the definition of a form to gather user input."""
return { return {
'type': RESULT_TYPE_FORM, 'type': RESULT_TYPE_FORM,
@ -147,7 +153,7 @@ class FlowHandler:
} }
@callback @callback
def async_create_entry(self, *, title, data): def async_create_entry(self, *, title: str, data: Dict) -> Dict:
"""Finish config flow and create a config entry.""" """Finish config flow and create a config entry."""
return { return {
'version': self.VERSION, 'version': self.VERSION,
@ -160,7 +166,7 @@ class FlowHandler:
} }
@callback @callback
def async_abort(self, *, reason): def async_abort(self, *, reason: str) -> Dict:
"""Abort the config flow.""" """Abort the config flow."""
return { return {
'type': RESULT_TYPE_ABORT, 'type': RESULT_TYPE_ABORT,

View file

@ -17,7 +17,7 @@ import sys
from types import ModuleType from types import ModuleType
# pylint: disable=unused-import # pylint: disable=unused-import
from typing import Optional, Set, TYPE_CHECKING # NOQA from typing import Optional, Set, TYPE_CHECKING, Callable, Any, TypeVar # NOQA
from homeassistant.const import PLATFORM_FORMAT from homeassistant.const import PLATFORM_FORMAT
from homeassistant.util import OrderedSet from homeassistant.util import OrderedSet
@ -27,6 +27,8 @@ from homeassistant.util import OrderedSet
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.core import HomeAssistant # NOQA from homeassistant.core import HomeAssistant # NOQA
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
PREPARED = False PREPARED = False
DEPENDENCY_BLACKLIST = {'config'} DEPENDENCY_BLACKLIST = {'config'}
@ -51,7 +53,8 @@ def set_component(hass, # type: HomeAssistant
cache[comp_name] = component cache[comp_name] = component
def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]: def get_platform(hass, # type: HomeAssistant
domain: str, platform: str) -> Optional[ModuleType]:
"""Try to load specified platform. """Try to load specified platform.
Async friendly. Async friendly.
@ -59,7 +62,8 @@ def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]:
return get_component(hass, PLATFORM_FORMAT.format(domain, platform)) return get_component(hass, PLATFORM_FORMAT.format(domain, platform))
def get_component(hass, comp_or_platform) -> Optional[ModuleType]: def get_component(hass, # type: HomeAssistant
comp_or_platform: str) -> Optional[ModuleType]:
"""Try to load specified component. """Try to load specified component.
Looks in config dir first, then built-in components. Looks in config dir first, then built-in components.
@ -73,6 +77,9 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
cache = hass.data.get(DATA_KEY) cache = hass.data.get(DATA_KEY)
if cache is None: if cache is None:
if hass.config.config_dir is None:
_LOGGER.error("Can't load components - config dir is not set")
return None
# Only insert if it's not there (happens during tests) # Only insert if it's not there (happens during tests)
if sys.path[0] != hass.config.config_dir: if sys.path[0] != hass.config.config_dir:
sys.path.insert(0, hass.config.config_dir) sys.path.insert(0, hass.config.config_dir)
@ -134,14 +141,38 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
return None return None
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self,
hass, # type: HomeAssistant
module: ModuleType) -> None:
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
def __getattr__(self, attr: str) -> Any:
"""Fetch an attribute."""
value = getattr(self._module, attr)
if hasattr(value, '__bind_hass'):
value = ft.partial(value, self._hass)
setattr(self, attr, value)
return value
class Components: class Components:
"""Helper to load components.""" """Helper to load components."""
def __init__(self, hass): def __init__(
self,
hass # type: HomeAssistant
) -> None:
"""Initialize the Components class.""" """Initialize the Components class."""
self._hass = hass self._hass = hass
def __getattr__(self, comp_name): def __getattr__(self, comp_name: str) -> ModuleWrapper:
"""Fetch a component.""" """Fetch a component."""
component = get_component(self._hass, comp_name) component = get_component(self._hass, comp_name)
if component is None: if component is None:
@ -154,11 +185,14 @@ class Components:
class Helpers: class Helpers:
"""Helper to load helpers.""" """Helper to load helpers."""
def __init__(self, hass): def __init__(
self,
hass # type: HomeAssistant
) -> None:
"""Initialize the Helpers class.""" """Initialize the Helpers class."""
self._hass = hass self._hass = hass
def __getattr__(self, helper_name): def __getattr__(self, helper_name: str) -> ModuleWrapper:
"""Fetch a helper.""" """Fetch a helper."""
helper = importlib.import_module( helper = importlib.import_module(
'homeassistant.helpers.{}'.format(helper_name)) 'homeassistant.helpers.{}'.format(helper_name))
@ -167,33 +201,14 @@ class Helpers:
return wrapped return wrapped
class ModuleWrapper: def bind_hass(func: CALLABLE_T) -> CALLABLE_T:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self, hass, module):
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
def __getattr__(self, attr):
"""Fetch an attribute."""
value = getattr(self._module, attr)
if hasattr(value, '__bind_hass'):
value = ft.partial(value, self._hass)
setattr(self, attr, value)
return value
def bind_hass(func):
"""Decorate function to indicate that first argument is hass.""" """Decorate function to indicate that first argument is hass."""
# pylint: disable=protected-access setattr(func, '__bind_hass', True)
func.__bind_hass = True
return func return func
def load_order_component(hass, comp_name: str) -> OrderedSet: def load_order_component(hass, # type: HomeAssistant
comp_name: str) -> OrderedSet:
"""Return an OrderedSet of components in the correct order of loading. """Return an OrderedSet of components in the correct order of loading.
Raises HomeAssistantError if a circular dependency is detected. Raises HomeAssistantError if a circular dependency is detected.
@ -204,7 +219,8 @@ def load_order_component(hass, comp_name: str) -> OrderedSet:
return _load_order_component(hass, comp_name, OrderedSet(), set()) return _load_order_component(hass, comp_name, OrderedSet(), set())
def _load_order_component(hass, comp_name: str, load_order: OrderedSet, def _load_order_component(hass, # type: HomeAssistant
comp_name: str, load_order: OrderedSet,
loading: Set) -> OrderedSet: loading: Set) -> OrderedSet:
"""Recursive function to get load order of components. """Recursive function to get load order of components.

View file

@ -20,9 +20,10 @@ Related Python bugs:
- https://bugs.python.org/issue26617 - https://bugs.python.org/issue26617
""" """
import sys import sys
from typing import Any
def patch_weakref_tasks(): def patch_weakref_tasks() -> None:
"""Replace weakref.WeakSet to address Python 3 bug.""" """Replace weakref.WeakSet to address Python 3 bug."""
# pylint: disable=no-self-use, protected-access, bare-except # pylint: disable=no-self-use, protected-access, bare-except
import asyncio.tasks import asyncio.tasks
@ -30,7 +31,7 @@ def patch_weakref_tasks():
class IgnoreCalls: class IgnoreCalls:
"""Ignore add calls.""" """Ignore add calls."""
def add(self, other): def add(self, other: Any) -> None:
"""No-op add.""" """No-op add."""
return return
@ -41,7 +42,7 @@ def patch_weakref_tasks():
pass pass
def disable_c_asyncio(): def disable_c_asyncio() -> None:
"""Disable using C implementation of asyncio. """Disable using C implementation of asyncio.
Required to be able to apply the weakref monkey patch. Required to be able to apply the weakref monkey patch.
@ -53,12 +54,12 @@ def disable_c_asyncio():
PATH_TRIGGER = '_asyncio' PATH_TRIGGER = '_asyncio'
def __init__(self, path_entry): def __init__(self, path_entry: str) -> None:
if path_entry != self.PATH_TRIGGER: if path_entry != self.PATH_TRIGGER:
raise ImportError() raise ImportError()
return return
def find_module(self, fullname, path=None): def find_module(self, fullname: str, path: Any = None) -> None:
"""Find a module.""" """Find a module."""
if fullname == self.PATH_TRIGGER: if fullname == self.PATH_TRIGGER:
# We lint in Py35, exception is introduced in Py36 # We lint in Py35, exception is introduced in Py36

View file

@ -13,7 +13,7 @@ import json
import logging import logging
import urllib.parse import urllib.parse
from typing import Optional from typing import Optional, Dict, Any, List
from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE
import requests import requests
@ -62,7 +62,7 @@ class API:
if port is not None: if port is not None:
self.base_url += ':{}'.format(port) self.base_url += ':{}'.format(port)
self.status = None self.status = None # type: Optional[APIStatus]
self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON} self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON}
if api_password is not None: if api_password is not None:
@ -75,20 +75,24 @@ class API:
return self.status == APIStatus.OK return self.status == APIStatus.OK
def __call__(self, method, path, data=None, timeout=5): def __call__(self, method: str, path: str, data: Dict = None,
timeout: int = 5) -> requests.Response:
"""Make a call to the Home Assistant API.""" """Make a call to the Home Assistant API."""
if data is not None: if data is None:
data = json.dumps(data, cls=JSONEncoder) data_str = None
else:
data_str = json.dumps(data, cls=JSONEncoder)
url = urllib.parse.urljoin(self.base_url, path) url = urllib.parse.urljoin(self.base_url, path)
try: try:
if method == METH_GET: if method == METH_GET:
return requests.get( return requests.get(
url, params=data, timeout=timeout, headers=self._headers) url, params=data_str, timeout=timeout,
headers=self._headers)
return requests.request( return requests.request(
method, url, data=data, timeout=timeout, method, url, data=data_str, timeout=timeout,
headers=self._headers) headers=self._headers)
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
@ -110,7 +114,7 @@ class JSONEncoder(json.JSONEncoder):
"""JSONEncoder that supports Home Assistant objects.""" """JSONEncoder that supports Home Assistant objects."""
# pylint: disable=method-hidden # pylint: disable=method-hidden
def default(self, o): def default(self, o: Any) -> Any:
"""Convert Home Assistant objects. """Convert Home Assistant objects.
Hand other objects to the original method. Hand other objects to the original method.
@ -125,7 +129,7 @@ class JSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, o)
def validate_api(api): def validate_api(api: API) -> APIStatus:
"""Make a call to validate API.""" """Make a call to validate API."""
try: try:
req = api(METH_GET, URL_API) req = api(METH_GET, URL_API)
@ -142,12 +146,12 @@ def validate_api(api):
return APIStatus.CANNOT_CONNECT return APIStatus.CANNOT_CONNECT
def get_event_listeners(api): def get_event_listeners(api: API) -> Dict:
"""List of events that is being listened for.""" """List of events that is being listened for."""
try: try:
req = api(METH_GET, URL_API_EVENTS) req = api(METH_GET, URL_API_EVENTS)
return req.json() if req.status_code == 200 else {} return req.json() if req.status_code == 200 else {} # type: ignore
except (HomeAssistantError, ValueError): except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the json # ValueError if req.json() can't parse the json
@ -156,7 +160,7 @@ def get_event_listeners(api):
return {} return {}
def fire_event(api, event_type, data=None): def fire_event(api: API, event_type: str, data: Dict = None) -> None:
"""Fire an event at remote API.""" """Fire an event at remote API."""
try: try:
req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data) req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data)
@ -169,7 +173,7 @@ def fire_event(api, event_type, data=None):
_LOGGER.exception("Error firing event") _LOGGER.exception("Error firing event")
def get_state(api, entity_id): def get_state(api: API, entity_id: str) -> Optional[ha.State]:
"""Query given API for state of entity_id.""" """Query given API for state of entity_id."""
try: try:
req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id)) req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id))
@ -186,7 +190,7 @@ def get_state(api, entity_id):
return None return None
def get_states(api): def get_states(api: API) -> List[ha.State]:
"""Query given API for all states.""" """Query given API for all states."""
try: try:
req = api(METH_GET, req = api(METH_GET,
@ -202,7 +206,7 @@ def get_states(api):
return [] return []
def remove_state(api, entity_id): def remove_state(api: API, entity_id: str) -> bool:
"""Call API to remove state for entity_id. """Call API to remove state for entity_id.
Return True if entity is gone (removed/never existed). Return True if entity is gone (removed/never existed).
@ -222,7 +226,8 @@ def remove_state(api, entity_id):
return False return False
def set_state(api, entity_id, new_state, attributes=None, force_update=False): def set_state(api: API, entity_id: str, new_state: str,
attributes: Dict = None, force_update: bool = False) -> bool:
"""Tell API to update state for entity_id. """Tell API to update state for entity_id.
Return True if success. Return True if success.
@ -249,14 +254,14 @@ def set_state(api, entity_id, new_state, attributes=None, force_update=False):
return False return False
def is_state(api, entity_id, state): def is_state(api: API, entity_id: str, state: str) -> bool:
"""Query API to see if entity_id is specified state.""" """Query API to see if entity_id is specified state."""
cur_state = get_state(api, entity_id) cur_state = get_state(api, entity_id)
return cur_state and cur_state.state == state return bool(cur_state and cur_state.state == state)
def get_services(api): def get_services(api: API) -> Dict:
"""Return a list of dicts. """Return a list of dicts.
Each dict has a string "domain" and a list of strings "services". Each dict has a string "domain" and a list of strings "services".
@ -264,7 +269,7 @@ def get_services(api):
try: try:
req = api(METH_GET, URL_API_SERVICES) req = api(METH_GET, URL_API_SERVICES)
return req.json() if req.status_code == 200 else {} return req.json() if req.status_code == 200 else {} # type: ignore
except (HomeAssistantError, ValueError): except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the json # ValueError if req.json() can't parse the json
@ -273,7 +278,9 @@ def get_services(api):
return {} return {}
def call_service(api, domain, service, service_data=None, timeout=5): def call_service(api: API, domain: str, service: str,
service_data: Dict = None,
timeout: int = 5) -> None:
"""Call a service at the remote API.""" """Call a service at the remote API."""
try: try:
req = api(METH_POST, req = api(METH_POST,
@ -288,7 +295,7 @@ def call_service(api, domain, service, service_data=None, timeout=5):
_LOGGER.exception("Error calling service") _LOGGER.exception("Error calling service")
def get_config(api): def get_config(api: API) -> Dict:
"""Return configuration.""" """Return configuration."""
try: try:
req = api(METH_GET, URL_API_CONFIG) req = api(METH_GET, URL_API_CONFIG)
@ -299,7 +306,7 @@ def get_config(api):
result = req.json() result = req.json()
if 'components' in result: if 'components' in result:
result['components'] = set(result['components']) result['components'] = set(result['components'])
return result return result # type: ignore
except (HomeAssistantError, ValueError): except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the JSON # ValueError if req.json() can't parse the JSON

View file

@ -3,15 +3,18 @@ import asyncio
from functools import partial from functools import partial
import logging import logging
import os import os
from typing import List, Dict, Optional
import homeassistant.util.package as pkg_util import homeassistant.util.package as pkg_util
from homeassistant.core import HomeAssistant
DATA_PIP_LOCK = 'pip_lock' DATA_PIP_LOCK = 'pip_lock'
CONSTRAINT_FILE = 'package_constraints.txt' CONSTRAINT_FILE = 'package_constraints.txt'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_process_requirements(hass, name, requirements): async def async_process_requirements(hass: HomeAssistant, name: str,
requirements: List[str]) -> bool:
"""Install the requirements for a component or platform. """Install the requirements for a component or platform.
This method is a coroutine. This method is a coroutine.
@ -25,7 +28,7 @@ async def async_process_requirements(hass, name, requirements):
async with pip_lock: async with pip_lock:
for req in requirements: for req in requirements:
ret = await hass.async_add_job(pip_install, req) ret = await hass.async_add_executor_job(pip_install, req)
if not ret: if not ret:
_LOGGER.error("Not initializing %s because could not install " _LOGGER.error("Not initializing %s because could not install "
"requirement %s", name, req) "requirement %s", name, req)
@ -34,11 +37,11 @@ async def async_process_requirements(hass, name, requirements):
return True return True
def pip_kwargs(config_dir): def pip_kwargs(config_dir: Optional[str]) -> Dict[str, str]:
"""Return keyword arguments for PIP install.""" """Return keyword arguments for PIP install."""
kwargs = { kwargs = {
'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE) 'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE)
} }
if not pkg_util.is_virtual_env(): if not (config_dir is None or pkg_util.is_virtual_env()):
kwargs['target'] = os.path.join(config_dir, 'deps') kwargs['target'] = os.path.join(config_dir, 'deps')
return kwargs return kwargs

View file

@ -4,7 +4,7 @@ import logging.handlers
from timeit import default_timer as timer from timeit import default_timer as timer
from types import ModuleType from types import ModuleType
from typing import Optional, Dict from typing import Optional, Dict, List
from homeassistant import requirements, core, loader, config as conf_util from homeassistant import requirements, core, loader, config as conf_util
from homeassistant.config import async_notify_setup_error from homeassistant.config import async_notify_setup_error
@ -56,7 +56,9 @@ async def async_setup_component(hass: core.HomeAssistant, domain: str,
return await task # type: ignore return await task # type: ignore
async def _async_process_dependencies(hass, config, name, dependencies): async def _async_process_dependencies(
hass: core.HomeAssistant, config: Dict, name: str,
dependencies: List[str]) -> bool:
"""Ensure all dependencies are set up.""" """Ensure all dependencies are set up."""
blacklisted = [dep for dep in dependencies blacklisted = [dep for dep in dependencies
if dep in loader.DEPENDENCY_BLACKLIST] if dep in loader.DEPENDENCY_BLACKLIST]
@ -88,12 +90,12 @@ async def _async_process_dependencies(hass, config, name, dependencies):
async def _async_setup_component(hass: core.HomeAssistant, async def _async_setup_component(hass: core.HomeAssistant,
domain: str, config) -> bool: domain: str, config: Dict) -> bool:
"""Set up a component for Home Assistant. """Set up a component for Home Assistant.
This method is a coroutine. This method is a coroutine.
""" """
def log_error(msg, link=True): def log_error(msg: str, link: bool = True) -> None:
"""Log helper.""" """Log helper."""
_LOGGER.error("Setup failed for %s: %s", domain, msg) _LOGGER.error("Setup failed for %s: %s", domain, msg)
async_notify_setup_error(hass, domain, link) async_notify_setup_error(hass, domain, link)
@ -181,7 +183,7 @@ async def _async_setup_component(hass: core.HomeAssistant,
return True return True
async def async_prepare_setup_platform(hass: core.HomeAssistant, config, async def async_prepare_setup_platform(hass: core.HomeAssistant, config: Dict,
domain: str, platform_name: str) \ domain: str, platform_name: str) \
-> Optional[ModuleType]: -> Optional[ModuleType]:
"""Load a platform and makes sure dependencies are setup. """Load a platform and makes sure dependencies are setup.
@ -190,7 +192,7 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
""" """
platform_path = PLATFORM_FORMAT.format(domain, platform_name) platform_path = PLATFORM_FORMAT.format(domain, platform_name)
def log_error(msg): def log_error(msg: str) -> None:
"""Log helper.""" """Log helper."""
_LOGGER.error("Unable to prepare setup for platform %s: %s", _LOGGER.error("Unable to prepare setup for platform %s: %s",
platform_path, msg) platform_path, msg)
@ -217,7 +219,9 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
return platform return platform
async def async_process_deps_reqs(hass, config, name, module): async def async_process_deps_reqs(
hass: core.HomeAssistant, config: Dict, name: str,
module: ModuleType) -> None:
"""Process all dependencies and requirements for a module. """Process all dependencies and requirements for a module.
Module is a Python module of either a component or platform. Module is a Python module of either a component or platform.
@ -231,14 +235,14 @@ async def async_process_deps_reqs(hass, config, name, module):
if hasattr(module, 'DEPENDENCIES'): if hasattr(module, 'DEPENDENCIES'):
dep_success = await _async_process_dependencies( dep_success = await _async_process_dependencies(
hass, config, name, module.DEPENDENCIES) hass, config, name, module.DEPENDENCIES) # type: ignore
if not dep_success: if not dep_success:
raise HomeAssistantError("Could not setup all dependencies.") raise HomeAssistantError("Could not setup all dependencies.")
if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'): if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'):
req_success = await requirements.async_process_requirements( req_success = await requirements.async_process_requirements(
hass, name, module.REQUIREMENTS) hass, name, module.REQUIREMENTS) # type: ignore
if not req_success: if not req_success:
raise HomeAssistantError("Could not install all requirements.") raise HomeAssistantError("Could not install all requirements.")

View file

@ -1,9 +1,8 @@
"""Helper methods for various modules.""" """Helper methods for various modules."""
import asyncio import asyncio
from collections.abc import MutableSet from datetime import datetime, timedelta
from itertools import chain from itertools import chain
import threading import threading
from datetime import datetime
import re import re
import enum import enum
import socket import socket
@ -14,12 +13,13 @@ from types import MappingProxyType
from unicodedata import normalize from unicodedata import normalize
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
Iterable, List, Mapping) Iterable, List, Dict, Iterator, Coroutine, MutableSet)
from .dt import as_local, utcnow from .dt import as_local, utcnow
T = TypeVar('T') T = TypeVar('T')
U = TypeVar('U') U = TypeVar('U')
ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
@ -91,7 +91,7 @@ def ensure_unique_string(preferred_string: str, current_strings:
# Taken from: http://stackoverflow.com/a/11735897 # Taken from: http://stackoverflow.com/a/11735897
def get_local_ip(): def get_local_ip() -> str:
"""Try to determine the local IP address of the machine.""" """Try to determine the local IP address of the machine."""
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -99,7 +99,7 @@ def get_local_ip():
# Use Google Public DNS server to determine own IP # Use Google Public DNS server to determine own IP
sock.connect(('8.8.8.8', 80)) sock.connect(('8.8.8.8', 80))
return sock.getsockname()[0] return sock.getsockname()[0] # type: ignore
except socket.error: except socket.error:
try: try:
return socket.gethostbyname(socket.gethostname()) return socket.gethostbyname(socket.gethostname())
@ -110,7 +110,7 @@ def get_local_ip():
# Taken from http://stackoverflow.com/a/23728630 # Taken from http://stackoverflow.com/a/23728630
def get_random_string(length=10): def get_random_string(length: int = 10) -> str:
"""Return a random string with letters and digits.""" """Return a random string with letters and digits."""
generator = random.SystemRandom() generator = random.SystemRandom()
source_chars = string.ascii_letters + string.digits source_chars = string.ascii_letters + string.digits
@ -121,59 +121,59 @@ def get_random_string(length=10):
class OrderedEnum(enum.Enum): class OrderedEnum(enum.Enum):
"""Taken from Python 3.4.0 docs.""" """Taken from Python 3.4.0 docs."""
def __ge__(self, other): def __ge__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the greater than element.""" """Return the greater than element."""
if self.__class__ is other.__class__: if self.__class__ is other.__class__:
return self.value >= other.value return bool(self.value >= other.value)
return NotImplemented return NotImplemented
def __gt__(self, other): def __gt__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the greater element.""" """Return the greater element."""
if self.__class__ is other.__class__: if self.__class__ is other.__class__:
return self.value > other.value return bool(self.value > other.value)
return NotImplemented return NotImplemented
def __le__(self, other): def __le__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the lower than element.""" """Return the lower than element."""
if self.__class__ is other.__class__: if self.__class__ is other.__class__:
return self.value <= other.value return bool(self.value <= other.value)
return NotImplemented return NotImplemented
def __lt__(self, other): def __lt__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the lower element.""" """Return the lower element."""
if self.__class__ is other.__class__: if self.__class__ is other.__class__:
return self.value < other.value return bool(self.value < other.value)
return NotImplemented return NotImplemented
class OrderedSet(MutableSet): class OrderedSet(MutableSet[T]):
"""Ordered set taken from http://code.activestate.com/recipes/576694/.""" """Ordered set taken from http://code.activestate.com/recipes/576694/."""
def __init__(self, iterable=None): def __init__(self, iterable: Iterable[T] = None) -> None:
"""Initialize the set.""" """Initialize the set."""
self.end = end = [] # type: List[Any] self.end = end = [] # type: List[Any]
end += [None, end, end] # sentinel node for doubly linked list end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # type: Mapping[List, Any] # key --> [key, prev, next] self.map = {} # type: Dict[T, List] # key --> [key, prev, next]
if iterable is not None: if iterable is not None:
self |= iterable self |= iterable # type: ignore
def __len__(self): def __len__(self) -> int:
"""Return the length of the set.""" """Return the length of the set."""
return len(self.map) return len(self.map)
def __contains__(self, key): def __contains__(self, key: T) -> bool: # type: ignore
"""Check if key is in set.""" """Check if key is in set."""
return key in self.map return key in self.map
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def add(self, key): def add(self, key: T) -> None:
"""Add an element to the end of the set.""" """Add an element to the end of the set."""
if key not in self.map: if key not in self.map:
end = self.end end = self.end
curr = end[1] curr = end[1]
curr[2] = end[1] = self.map[key] = [key, curr, end] curr[2] = end[1] = self.map[key] = [key, curr, end]
def promote(self, key): def promote(self, key: T) -> None:
"""Promote element to beginning of the set, add if not there.""" """Promote element to beginning of the set, add if not there."""
if key in self.map: if key in self.map:
self.discard(key) self.discard(key)
@ -183,14 +183,14 @@ class OrderedSet(MutableSet):
curr[2] = begin[1] = self.map[key] = [key, curr, begin] curr[2] = begin[1] = self.map[key] = [key, curr, begin]
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def discard(self, key): def discard(self, key: T) -> None:
"""Discard an element from the set.""" """Discard an element from the set."""
if key in self.map: if key in self.map:
key, prev_item, next_item = self.map.pop(key) key, prev_item, next_item = self.map.pop(key)
prev_item[2] = next_item prev_item[2] = next_item
next_item[1] = prev_item next_item[1] = prev_item
def __iter__(self): def __iter__(self) -> Iterator[T]:
"""Iterate of the set.""" """Iterate of the set."""
end = self.end end = self.end
curr = end[2] curr = end[2]
@ -198,7 +198,7 @@ class OrderedSet(MutableSet):
yield curr[0] yield curr[0]
curr = curr[2] curr = curr[2]
def __reversed__(self): def __reversed__(self) -> Iterator[T]:
"""Reverse the ordering.""" """Reverse the ordering."""
end = self.end end = self.end
curr = end[1] curr = end[1]
@ -207,7 +207,7 @@ class OrderedSet(MutableSet):
curr = curr[1] curr = curr[1]
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def pop(self, last=True): def pop(self, last: bool = True) -> T:
"""Pop element of the end of the set. """Pop element of the end of the set.
Set last=False to pop from the beginning. Set last=False to pop from the beginning.
@ -216,20 +216,20 @@ class OrderedSet(MutableSet):
raise KeyError('set is empty') raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0] key = self.end[1][0] if last else self.end[2][0]
self.discard(key) self.discard(key)
return key return key # type: ignore
def update(self, *args): def update(self, *args: Any) -> None:
"""Add elements from args to the set.""" """Add elements from args to the set."""
for item in chain(*args): for item in chain(*args):
self.add(item) self.add(item)
def __repr__(self): def __repr__(self) -> str:
"""Return the representation.""" """Return the representation."""
if not self: if not self:
return '%s()' % (self.__class__.__name__,) return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self)) return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Return the comparison.""" """Return the comparison."""
if isinstance(other, OrderedSet): if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other) return len(self) == len(other) and list(self) == list(other)
@ -254,20 +254,21 @@ class Throttle:
Adds a datetime attribute `last_call` to the method. Adds a datetime attribute `last_call` to the method.
""" """
def __init__(self, min_time, limit_no_throttle=None): def __init__(self, min_time: timedelta,
limit_no_throttle: timedelta = None) -> None:
"""Initialize the throttle.""" """Initialize the throttle."""
self.min_time = min_time self.min_time = min_time
self.limit_no_throttle = limit_no_throttle self.limit_no_throttle = limit_no_throttle
def __call__(self, method): def __call__(self, method: Callable) -> Callable:
"""Caller for the throttle.""" """Caller for the throttle."""
# Make sure we return a coroutine if the method is async. # Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method): if asyncio.iscoroutinefunction(method):
async def throttled_value(): async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled.""" """Stand-in function for when real func is being throttled."""
return None return None
else: else:
def throttled_value(): def throttled_value() -> None: # type: ignore
"""Stand-in function for when real func is being throttled.""" """Stand-in function for when real func is being throttled."""
return None return None
@ -288,14 +289,14 @@ class Throttle:
'.' not in method.__qualname__.split('.<locals>.')[-1]) '.' not in method.__qualname__.split('.<locals>.')[-1])
@wraps(method) @wraps(method)
def wrapper(*args, **kwargs): def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
"""Wrap that allows wrapped to be called only once per min_time. """Wrap that allows wrapped to be called only once per min_time.
If we cannot acquire the lock, it is running so return None. If we cannot acquire the lock, it is running so return None.
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
if hasattr(method, '__self__'): if hasattr(method, '__self__'):
host = method.__self__ host = getattr(method, '__self__')
elif is_func: elif is_func:
host = wrapper host = wrapper
else: else:
@ -318,7 +319,7 @@ class Throttle:
if force or utcnow() - throttle[1] > self.min_time: if force or utcnow() - throttle[1] > self.min_time:
result = method(*args, **kwargs) result = method(*args, **kwargs)
throttle[1] = utcnow() throttle[1] = utcnow()
return result return result # type: ignore
return throttled_value() return throttled_value()
finally: finally:

View file

@ -3,22 +3,25 @@ import concurrent.futures
import threading import threading
import logging import logging
from asyncio import coroutines from asyncio import coroutines
from asyncio.events import AbstractEventLoop
from asyncio.futures import Future from asyncio.futures import Future
from asyncio import ensure_future from asyncio import ensure_future
from typing import Any, Union, Coroutine, Callable, Generator
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def _set_result_unless_cancelled(fut, result): def _set_result_unless_cancelled(fut: Future, result: Any) -> None:
"""Set the result only if the Future was not cancelled.""" """Set the result only if the Future was not cancelled."""
if fut.cancelled(): if fut.cancelled():
return return
fut.set_result(result) fut.set_result(result)
def _set_concurrent_future_state(concurr, source): def _set_concurrent_future_state(
concurr: concurrent.futures.Future,
source: Union[concurrent.futures.Future, Future]) -> None:
"""Copy state from a future to a concurrent.futures.Future.""" """Copy state from a future to a concurrent.futures.Future."""
assert source.done() assert source.done()
if source.cancelled(): if source.cancelled():
@ -33,7 +36,8 @@ def _set_concurrent_future_state(concurr, source):
concurr.set_result(result) concurr.set_result(result)
def _copy_future_state(source, dest): def _copy_future_state(source: Union[concurrent.futures.Future, Future],
dest: Union[concurrent.futures.Future, Future]) -> None:
"""Copy state from another Future. """Copy state from another Future.
The other Future may be a concurrent.futures.Future. The other Future may be a concurrent.futures.Future.
@ -53,7 +57,9 @@ def _copy_future_state(source, dest):
dest.set_result(result) dest.set_result(result)
def _chain_future(source, destination): def _chain_future(
source: Union[concurrent.futures.Future, Future],
destination: Union[concurrent.futures.Future, Future]) -> None:
"""Chain two futures so that when one completes, so does the other. """Chain two futures so that when one completes, so does the other.
The result (or exception) of source will be copied to destination. The result (or exception) of source will be copied to destination.
@ -74,20 +80,23 @@ def _chain_future(source, destination):
else: else:
dest_loop = None dest_loop = None
def _set_state(future, other): def _set_state(future: Union[concurrent.futures.Future, Future],
other: Union[concurrent.futures.Future, Future]) -> None:
if isinstance(future, Future): if isinstance(future, Future):
_copy_future_state(other, future) _copy_future_state(other, future)
else: else:
_set_concurrent_future_state(future, other) _set_concurrent_future_state(future, other)
def _call_check_cancel(destination): def _call_check_cancel(
destination: Union[concurrent.futures.Future, Future]) -> None:
if destination.cancelled(): if destination.cancelled():
if source_loop is None or source_loop is dest_loop: if source_loop is None or source_loop is dest_loop:
source.cancel() source.cancel()
else: else:
source_loop.call_soon_threadsafe(source.cancel) source_loop.call_soon_threadsafe(source.cancel)
def _call_set_state(source): def _call_set_state(
source: Union[concurrent.futures.Future, Future]) -> None:
if dest_loop is None or dest_loop is source_loop: if dest_loop is None or dest_loop is source_loop:
_set_state(destination, source) _set_state(destination, source)
else: else:
@ -97,7 +106,9 @@ def _chain_future(source, destination):
source.add_done_callback(_call_set_state) source.add_done_callback(_call_set_state)
def run_coroutine_threadsafe(coro, loop): def run_coroutine_threadsafe(
coro: Union[Coroutine, Generator],
loop: AbstractEventLoop) -> concurrent.futures.Future:
"""Submit a coroutine object to a given event loop. """Submit a coroutine object to a given event loop.
Return a concurrent.futures.Future to access the result. Return a concurrent.futures.Future to access the result.
@ -110,7 +121,7 @@ def run_coroutine_threadsafe(coro, loop):
raise TypeError('A coroutine object is required') raise TypeError('A coroutine object is required')
future = concurrent.futures.Future() # type: concurrent.futures.Future future = concurrent.futures.Future() # type: concurrent.futures.Future
def callback(): def callback() -> None:
"""Handle the call to the coroutine.""" """Handle the call to the coroutine."""
try: try:
_chain_future(ensure_future(coro, loop=loop), future) _chain_future(ensure_future(coro, loop=loop), future)
@ -125,7 +136,8 @@ def run_coroutine_threadsafe(coro, loop):
return future return future
def fire_coroutine_threadsafe(coro, loop): def fire_coroutine_threadsafe(coro: Coroutine,
loop: AbstractEventLoop) -> None:
"""Submit a coroutine object to a given event loop. """Submit a coroutine object to a given event loop.
This method does not provide a way to retrieve the result and This method does not provide a way to retrieve the result and
@ -139,7 +151,7 @@ def fire_coroutine_threadsafe(coro, loop):
if not coroutines.iscoroutine(coro): if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required: %s' % coro) raise TypeError('A coroutine object is required: %s' % coro)
def callback(): def callback() -> None:
"""Handle the firing of a coroutine.""" """Handle the firing of a coroutine."""
ensure_future(coro, loop=loop) ensure_future(coro, loop=loop)
@ -147,7 +159,8 @@ def fire_coroutine_threadsafe(coro, loop):
return return
def run_callback_threadsafe(loop, callback, *args): def run_callback_threadsafe(loop: AbstractEventLoop, callback: Callable,
*args: Any) -> concurrent.futures.Future:
"""Submit a callback object to a given event loop. """Submit a callback object to a given event loop.
Return a concurrent.futures.Future to access the result. Return a concurrent.futures.Future to access the result.
@ -158,7 +171,7 @@ def run_callback_threadsafe(loop, callback, *args):
future = concurrent.futures.Future() # type: concurrent.futures.Future future = concurrent.futures.Future() # type: concurrent.futures.Future
def run_callback(): def run_callback() -> None:
"""Run callback and store result.""" """Run callback and store result."""
try: try:
future.set_result(callback(*args)) future.set_result(callback(*args))

View file

@ -2,7 +2,7 @@
import math import math
import colorsys import colorsys
from typing import Tuple from typing import Tuple, List
# Official CSS3 colors from w3.org: # Official CSS3 colors from w3.org:
# https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4 # https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4
@ -162,7 +162,7 @@ COLORS = {
} }
def color_name_to_rgb(color_name): def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
"""Convert color name to RGB hex value.""" """Convert color name to RGB hex value."""
# COLORS map has no spaces in it, so make the color_name have no # COLORS map has no spaces in it, so make the color_name have no
# spaces in it as well for matching purposes # spaces in it as well for matching purposes
@ -305,7 +305,8 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
return (r, g, b) return (r, g, b)
def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]: def color_RGB_to_hsv(
iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
"""Convert an rgb color to its hsv representation. """Convert an rgb color to its hsv representation.
Hue is scaled 0-360 Hue is scaled 0-360
@ -316,7 +317,7 @@ def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]:
return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3) return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3)
def color_RGB_to_hs(iR: int, iG: int, iB: int) -> Tuple[float, float]: def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
"""Convert an rgb color to its hs representation.""" """Convert an rgb color to its hs representation."""
return color_RGB_to_hsv(iR, iG, iB)[:2] return color_RGB_to_hsv(iR, iG, iB)[:2]
@ -340,7 +341,7 @@ def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]: def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]:
"""Convert an xy color to its hs representation.""" """Convert an xy color to its hs representation."""
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY)) h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY))
return (h, s) return h, s
def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]: def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
@ -348,8 +349,7 @@ def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS)) return color_RGB_to_xy(*color_hs_to_RGB(iH, iS))
def _match_max_scale(input_colors: Tuple[int, ...], def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
output_colors: Tuple[int, ...]) -> Tuple[int, ...]:
"""Match the maximum value of the output to the input.""" """Match the maximum value of the output to the input."""
max_in = max(input_colors) max_in = max(input_colors)
max_out = max(output_colors) max_out = max(output_colors)
@ -360,7 +360,7 @@ def _match_max_scale(input_colors: Tuple[int, ...],
return tuple(int(round(i * factor)) for i in output_colors) return tuple(int(round(i * factor)) for i in output_colors)
def color_rgb_to_rgbw(r, g, b): def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
"""Convert an rgb color to an rgbw representation.""" """Convert an rgb color to an rgbw representation."""
# Calculate the white channel as the minimum of input rgb channels. # Calculate the white channel as the minimum of input rgb channels.
# Subtract the white portion from the remaining rgb channels. # Subtract the white portion from the remaining rgb channels.
@ -369,25 +369,25 @@ def color_rgb_to_rgbw(r, g, b):
# Match the output maximum value to the input. This ensures the full # Match the output maximum value to the input. This ensures the full
# channel range is used. # channel range is used.
return _match_max_scale((r, g, b), rgbw) return _match_max_scale((r, g, b), rgbw) # type: ignore
def color_rgbw_to_rgb(r, g, b, w): def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
"""Convert an rgbw color to an rgb representation.""" """Convert an rgbw color to an rgb representation."""
# Add the white channel back into the rgb channels. # Add the white channel back into the rgb channels.
rgb = (r + w, g + w, b + w) rgb = (r + w, g + w, b + w)
# Match the output maximum value to the input. This ensures the # Match the output maximum value to the input. This ensures the
# output doesn't overflow. # output doesn't overflow.
return _match_max_scale((r, g, b, w), rgb) return _match_max_scale((r, g, b, w), rgb) # type: ignore
def color_rgb_to_hex(r, g, b): def color_rgb_to_hex(r: int, g: int, b: int) -> str:
"""Return a RGB color from a hex color string.""" """Return a RGB color from a hex color string."""
return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b)) return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b))
def rgb_hex_to_rgb_list(hex_string): def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
"""Return an RGB color value list from a hex color string.""" """Return an RGB color value list from a hex color string."""
return [int(hex_string[i:i + len(hex_string) // 3], 16) return [int(hex_string[i:i + len(hex_string) // 3], 16)
for i in range(0, for i in range(0,
@ -395,12 +395,14 @@ def rgb_hex_to_rgb_list(hex_string):
len(hex_string) // 3)] len(hex_string) // 3)]
def color_temperature_to_hs(color_temperature_kelvin): def color_temperature_to_hs(
color_temperature_kelvin: float) -> Tuple[float, float]:
"""Return an hs color from a color temperature in Kelvin.""" """Return an hs color from a color temperature in Kelvin."""
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin)) return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
def color_temperature_to_rgb(color_temperature_kelvin): def color_temperature_to_rgb(
color_temperature_kelvin: float) -> Tuple[float, float, float]:
""" """
Return an RGB color from a color temperature in Kelvin. Return an RGB color from a color temperature in Kelvin.
@ -421,7 +423,7 @@ def color_temperature_to_rgb(color_temperature_kelvin):
blue = _get_blue(tmp_internal) blue = _get_blue(tmp_internal)
return (red, green, blue) return red, green, blue
def _bound(color_component: float, minimum: float = 0, def _bound(color_component: float, minimum: float = 0,
@ -464,11 +466,11 @@ def _get_blue(temperature: float) -> float:
return _bound(blue) return _bound(blue)
def color_temperature_mired_to_kelvin(mired_temperature): def color_temperature_mired_to_kelvin(mired_temperature: float) -> float:
"""Convert absolute mired shift to degrees kelvin.""" """Convert absolute mired shift to degrees kelvin."""
return math.floor(1000000 / mired_temperature) return math.floor(1000000 / mired_temperature)
def color_temperature_kelvin_to_mired(kelvin_temperature): def color_temperature_kelvin_to_mired(kelvin_temperature: float) -> float:
"""Convert degrees kelvin to mired shift.""" """Convert degrees kelvin to mired shift."""
return math.floor(1000000 / kelvin_temperature) return math.floor(1000000 / kelvin_temperature)

View file

@ -1,12 +1,14 @@
"""Decorator utility functions.""" """Decorator utility functions."""
from typing import Callable, TypeVar
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
class Registry(dict): class Registry(dict):
"""Registry of items.""" """Registry of items."""
def register(self, name): def register(self, name: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Return decorator to register item with a specific name.""" """Return decorator to register item with a specific name."""
def decorator(func): def decorator(func: CALLABLE_T) -> CALLABLE_T:
"""Register decorated function.""" """Register decorated function."""
self[name] = func self[name] = func
return func return func

View file

@ -71,13 +71,13 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
return dattim.astimezone(UTC) return dattim.astimezone(UTC)
def as_timestamp(dt_value): def as_timestamp(dt_value: dt.datetime) -> float:
"""Convert a date/time into a unix time (seconds since 1970).""" """Convert a date/time into a unix time (seconds since 1970)."""
if hasattr(dt_value, "timestamp"): if hasattr(dt_value, "timestamp"):
parsed_dt = dt_value parsed_dt = dt_value # type: Optional[dt.datetime]
else: else:
parsed_dt = parse_datetime(str(dt_value)) parsed_dt = parse_datetime(str(dt_value))
if not parsed_dt: if parsed_dt is None:
raise ValueError("not a valid date/time.") raise ValueError("not a valid date/time.")
return parsed_dt.timestamp() return parsed_dt.timestamp()
@ -150,7 +150,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
return None return None
def parse_time(time_str): def parse_time(time_str: str) -> Optional[dt.time]:
"""Parse a time string (00:20:00) into Time object. """Parse a time string (00:20:00) into Time object.
Return None if invalid. Return None if invalid.

View file

@ -38,7 +38,7 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \
return {} if default is None else default return {} if default is None else default
def save_json(filename: str, data: Union[List, Dict]): def save_json(filename: str, data: Union[List, Dict]) -> None:
"""Save JSON data to a file. """Save JSON data to a file.
Returns True on success. Returns True on success.

View file

@ -33,7 +33,7 @@ LocationInfo = collections.namedtuple(
'use_metric']) 'use_metric'])
def detect_location_info(): def detect_location_info() -> Optional[LocationInfo]:
"""Detect location information.""" """Detect location information."""
data = _get_freegeoip() data = _get_freegeoip()
@ -63,7 +63,7 @@ def distance(lat1: Optional[float], lon1: Optional[float],
return result * 1000 return result * 1000
def elevation(latitude, longitude): def elevation(latitude: float, longitude: float) -> int:
"""Return elevation for given latitude and longitude.""" """Return elevation for given latitude and longitude."""
try: try:
req = requests.get( req = requests.get(

View file

@ -1,7 +1,9 @@
"""Logging utilities.""" """Logging utilities."""
import asyncio import asyncio
from asyncio.events import AbstractEventLoop
import logging import logging
import threading import threading
from typing import Optional
from .async_ import run_coroutine_threadsafe from .async_ import run_coroutine_threadsafe
@ -9,12 +11,12 @@ from .async_ import run_coroutine_threadsafe
class HideSensitiveDataFilter(logging.Filter): class HideSensitiveDataFilter(logging.Filter):
"""Filter API password calls.""" """Filter API password calls."""
def __init__(self, text): def __init__(self, text: str) -> None:
"""Initialize sensitive data filter.""" """Initialize sensitive data filter."""
super().__init__() super().__init__()
self.text = text self.text = text
def filter(self, record): def filter(self, record: logging.LogRecord) -> bool:
"""Hide sensitive data in messages.""" """Hide sensitive data in messages."""
record.msg = record.msg.replace(self.text, '*******') record.msg = record.msg.replace(self.text, '*******')
@ -25,7 +27,8 @@ class HideSensitiveDataFilter(logging.Filter):
class AsyncHandler: class AsyncHandler:
"""Logging handler wrapper to add an async layer.""" """Logging handler wrapper to add an async layer."""
def __init__(self, loop, handler): def __init__(
self, loop: AbstractEventLoop, handler: logging.Handler) -> None:
"""Initialize async logging handler wrapper.""" """Initialize async logging handler wrapper."""
self.handler = handler self.handler = handler
self.loop = loop self.loop = loop
@ -45,11 +48,11 @@ class AsyncHandler:
self._thread.start() self._thread.start()
def close(self): def close(self) -> None:
"""Wrap close to handler.""" """Wrap close to handler."""
self.emit(None) self.emit(None)
async def async_close(self, blocking=False): async def async_close(self, blocking: bool = False) -> None:
"""Close the handler. """Close the handler.
When blocking=True, will wait till closed. When blocking=True, will wait till closed.
@ -60,7 +63,7 @@ class AsyncHandler:
while self._thread.is_alive(): while self._thread.is_alive():
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
def emit(self, record): def emit(self, record: Optional[logging.LogRecord]) -> None:
"""Process a record.""" """Process a record."""
ident = self.loop.__dict__.get("_thread_ident") ident = self.loop.__dict__.get("_thread_ident")
@ -71,11 +74,11 @@ class AsyncHandler:
else: else:
self.loop.call_soon_threadsafe(self._queue.put_nowait, record) self.loop.call_soon_threadsafe(self._queue.put_nowait, record)
def __repr__(self): def __repr__(self) -> str:
"""Return the string names.""" """Return the string names."""
return str(self.handler) return str(self.handler)
def _process(self): def _process(self) -> None:
"""Process log in a thread.""" """Process log in a thread."""
while True: while True:
record = run_coroutine_threadsafe( record = run_coroutine_threadsafe(
@ -87,34 +90,34 @@ class AsyncHandler:
self.handler.emit(record) self.handler.emit(record)
def createLock(self): def createLock(self) -> None:
"""Ignore lock stuff.""" """Ignore lock stuff."""
pass pass
def acquire(self): def acquire(self) -> None:
"""Ignore lock stuff.""" """Ignore lock stuff."""
pass pass
def release(self): def release(self) -> None:
"""Ignore lock stuff.""" """Ignore lock stuff."""
pass pass
@property @property
def level(self): def level(self) -> int:
"""Wrap property level to handler.""" """Wrap property level to handler."""
return self.handler.level return self.handler.level
@property @property
def formatter(self): def formatter(self) -> Optional[logging.Formatter]:
"""Wrap property formatter to handler.""" """Wrap property formatter to handler."""
return self.handler.formatter return self.handler.formatter
@property @property
def name(self): def name(self) -> str:
"""Wrap property set_name to handler.""" """Wrap property set_name to handler."""
return self.handler.get_name() return self.handler.get_name() # type: ignore
@name.setter @name.setter
def name(self, name): def name(self, name: str) -> None:
"""Wrap property get_name to handler.""" """Wrap property get_name to handler."""
self.handler.name = name self.handler.set_name(name) # type: ignore

View file

@ -16,7 +16,7 @@ _LOGGER = logging.getLogger(__name__)
INSTALL_LOCK = threading.Lock() INSTALL_LOCK = threading.Lock()
def is_virtual_env(): def is_virtual_env() -> bool:
"""Return if we run in a virtual environtment.""" """Return if we run in a virtual environtment."""
# Check supports venv && virtualenv # Check supports venv && virtualenv
return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or

View file

@ -4,7 +4,7 @@ import ssl
import certifi import certifi
def client_context(): def client_context() -> ssl.SSLContext:
"""Return an SSL context for making requests.""" """Return an SSL context for making requests."""
context = ssl.create_default_context( context = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH, purpose=ssl.Purpose.SERVER_AUTH,
@ -13,7 +13,7 @@ def client_context():
return context return context
def server_context(): def server_context() -> ssl.SSLContext:
"""Return an SSL context following the Mozilla recommendations. """Return an SSL context following the Mozilla recommendations.
TLS configuration follows the best-practice guidelines specified here: TLS configuration follows the best-practice guidelines specified here:

View file

@ -4,7 +4,7 @@ import os
import sys import sys
import fnmatch import fnmatch
from collections import OrderedDict from collections import OrderedDict
from typing import Union, List, Dict from typing import Union, List, Dict, Iterator, overload, TypeVar
import yaml import yaml
try: try:
@ -22,7 +22,10 @@ from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SECRET_NAMESPACE = 'homeassistant' _SECRET_NAMESPACE = 'homeassistant'
SECRET_YAML = 'secrets.yaml' SECRET_YAML = 'secrets.yaml'
__SECRET_CACHE = {} # type: Dict __SECRET_CACHE = {} # type: Dict[str, JSON_TYPE]
JSON_TYPE = Union[List, Dict, str]
DICT_T = TypeVar('DICT_T', bound=Dict)
class NodeListClass(list): class NodeListClass(list):
@ -37,7 +40,42 @@ class NodeStrClass(str):
pass pass
def _add_reference(obj, loader, node): # pylint: disable=too-many-ancestors
class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers."""
def compose_node(self, parent: yaml.nodes.Node,
index: int) -> yaml.nodes.Node:
"""Annotate a node with the first line it was seen."""
last_line = self.line # type: int
node = super(SafeLineLoader,
self).compose_node(parent, index) # type: yaml.nodes.Node
node.__line__ = last_line + 1 # type: ignore
return node
# pylint: disable=pointless-statement
@overload
def _add_reference(obj: Union[list, NodeListClass],
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> NodeListClass: ...
@overload # noqa: F811
def _add_reference(obj: Union[str, NodeStrClass],
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> NodeStrClass: ...
@overload # noqa: F811
def _add_reference(obj: DICT_T,
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> DICT_T: ...
# pylint: enable=pointless-statement
def _add_reference(obj, loader: SafeLineLoader, # type: ignore # noqa: F811
node: yaml.nodes.Node):
"""Add file reference information to an object.""" """Add file reference information to an object."""
if isinstance(obj, list): if isinstance(obj, list):
obj = NodeListClass(obj) obj = NodeListClass(obj)
@ -48,20 +86,7 @@ def _add_reference(obj, loader, node):
return obj return obj
# pylint: disable=too-many-ancestors def load_yaml(fname: str) -> JSON_TYPE:
class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers."""
def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node:
"""Annotate a node with the first line it was seen."""
last_line = self.line # type: int
node = super(SafeLineLoader,
self).compose_node(parent, index) # type: yaml.nodes.Node
node.__line__ = last_line + 1 # type: ignore
return node
def load_yaml(fname: str) -> Union[List, Dict]:
"""Load a YAML file.""" """Load a YAML file."""
try: try:
with open(fname, encoding='utf-8') as conf_file: with open(fname, encoding='utf-8') as conf_file:
@ -83,12 +108,12 @@ def dump(_dict: dict) -> str:
.replace(': null\n', ':\n') .replace(': null\n', ':\n')
def save_yaml(path, data): def save_yaml(path: str, data: dict) -> None:
"""Save YAML to a file.""" """Save YAML to a file."""
# Dump before writing to not truncate the file if dumping fails # Dump before writing to not truncate the file if dumping fails
data = dump(data) str_data = dump(data)
with open(path, 'w', encoding='utf-8') as outfile: with open(path, 'w', encoding='utf-8') as outfile:
outfile.write(data) outfile.write(str_data)
def clear_secret_cache() -> None: def clear_secret_cache() -> None:
@ -100,7 +125,7 @@ def clear_secret_cache() -> None:
def _include_yaml(loader: SafeLineLoader, def _include_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node) -> Union[List, Dict]: node: yaml.nodes.Node) -> JSON_TYPE:
"""Load another YAML file and embeds it using the !include tag. """Load another YAML file and embeds it using the !include tag.
Example: Example:
@ -115,7 +140,7 @@ def _is_file_valid(name: str) -> bool:
return not name.startswith('.') return not name.startswith('.')
def _find_files(directory: str, pattern: str): def _find_files(directory: str, pattern: str) -> Iterator[str]:
"""Recursively load files in a directory.""" """Recursively load files in a directory."""
for root, dirs, files in os.walk(directory, topdown=True): for root, dirs, files in os.walk(directory, topdown=True):
dirs[:] = [d for d in dirs if _is_file_valid(d)] dirs[:] = [d for d in dirs if _is_file_valid(d)]
@ -151,7 +176,7 @@ def _include_dir_merge_named_yaml(loader: SafeLineLoader,
def _include_dir_list_yaml(loader: SafeLineLoader, def _include_dir_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node): node: yaml.nodes.Node) -> List[JSON_TYPE]:
"""Load multiple files from directory as a list.""" """Load multiple files from directory as a list."""
loc = os.path.join(os.path.dirname(loader.name), node.value) loc = os.path.join(os.path.dirname(loader.name), node.value)
return [load_yaml(f) for f in _find_files(loc, '*.yaml') return [load_yaml(f) for f in _find_files(loc, '*.yaml')
@ -159,11 +184,11 @@ def _include_dir_list_yaml(loader: SafeLineLoader,
def _include_dir_merge_list_yaml(loader: SafeLineLoader, def _include_dir_merge_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node): node: yaml.nodes.Node) -> JSON_TYPE:
"""Load multiple files from directory as a merged list.""" """Load multiple files from directory as a merged list."""
loc = os.path.join(os.path.dirname(loader.name), loc = os.path.join(os.path.dirname(loader.name),
node.value) # type: str node.value) # type: str
merged_list = [] # type: List merged_list = [] # type: List[JSON_TYPE]
for fname in _find_files(loc, '*.yaml'): for fname in _find_files(loc, '*.yaml'):
if os.path.basename(fname) == SECRET_YAML: if os.path.basename(fname) == SECRET_YAML:
continue continue
@ -202,14 +227,14 @@ def _ordered_dict(loader: SafeLineLoader,
return _add_reference(OrderedDict(nodes), loader, node) return _add_reference(OrderedDict(nodes), loader, node)
def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node): def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node) -> JSON_TYPE:
"""Add line number and file name to Load YAML sequence.""" """Add line number and file name to Load YAML sequence."""
obj, = loader.construct_yaml_seq(node) obj, = loader.construct_yaml_seq(node)
return _add_reference(obj, loader, node) return _add_reference(obj, loader, node)
def _env_var_yaml(loader: SafeLineLoader, def _env_var_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node): node: yaml.nodes.Node) -> str:
"""Load environment variables and embed it into the configuration YAML.""" """Load environment variables and embed it into the configuration YAML."""
args = node.value.split() args = node.value.split()
@ -222,7 +247,7 @@ def _env_var_yaml(loader: SafeLineLoader,
raise HomeAssistantError(node.value) raise HomeAssistantError(node.value)
def _load_secret_yaml(secret_path: str) -> Dict: def _load_secret_yaml(secret_path: str) -> JSON_TYPE:
"""Load the secrets yaml from path.""" """Load the secrets yaml from path."""
secret_path = os.path.join(secret_path, SECRET_YAML) secret_path = os.path.join(secret_path, SECRET_YAML)
if secret_path in __SECRET_CACHE: if secret_path in __SECRET_CACHE:
@ -248,7 +273,7 @@ def _load_secret_yaml(secret_path: str) -> Dict:
def _secret_yaml(loader: SafeLineLoader, def _secret_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node): node: yaml.nodes.Node) -> JSON_TYPE:
"""Load secrets and embed it into the configuration YAML.""" """Load secrets and embed it into the configuration YAML."""
secret_path = os.path.dirname(loader.name) secret_path = os.path.dirname(loader.name)
while True: while True:
@ -308,7 +333,8 @@ yaml.SafeLoader.add_constructor('!include_dir_merge_named',
# From: https://gist.github.com/miracle2k/3184458 # From: https://gist.github.com/miracle2k/3184458
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def represent_odict(dump, tag, mapping, flow_style=None): def represent_odict(dump, tag, mapping, # type: ignore
flow_style=None) -> yaml.MappingNode:
"""Like BaseRepresenter.represent_mapping but does not issue the sort().""" """Like BaseRepresenter.represent_mapping but does not issue the sort()."""
value = [] # type: list value = [] # type: list
node = yaml.MappingNode(tag, value, flow_style=flow_style) node = yaml.MappingNode(tag, value, flow_style=flow_style)

View file

@ -2,11 +2,18 @@
check_untyped_defs = true check_untyped_defs = true
follow_imports = silent follow_imports = silent
ignore_missing_imports = true ignore_missing_imports = true
warn_incomplete_stub = true
warn_redundant_casts = true warn_redundant_casts = true
warn_return_any = true warn_return_any = true
warn_unused_configs = true warn_unused_configs = true
warn_unused_ignores = true warn_unused_ignores = true
[mypy-homeassistant.*]
disallow_untyped_defs = true
[mypy-homeassistant.config_entries]
disallow_untyped_defs = false
[mypy-homeassistant.util.yaml] [mypy-homeassistant.util.yaml]
warn_return_any = false warn_return_any = false

View file

@ -437,6 +437,8 @@ class TestAutomation(unittest.TestCase):
} }
} }
}}): }}):
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass) automation.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
# De-flake ?! # De-flake ?!
@ -485,6 +487,8 @@ class TestAutomation(unittest.TestCase):
with patch('homeassistant.config.load_yaml_config_file', autospec=True, with patch('homeassistant.config.load_yaml_config_file', autospec=True,
return_value={automation.DOMAIN: 'not valid'}): return_value={automation.DOMAIN: 'not valid'}):
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass) automation.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
@ -521,6 +525,8 @@ class TestAutomation(unittest.TestCase):
with patch('homeassistant.config.load_yaml_config_file', with patch('homeassistant.config.load_yaml_config_file',
side_effect=HomeAssistantError('bla')): side_effect=HomeAssistantError('bla')):
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass) automation.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()

View file

@ -365,6 +365,8 @@ class TestComponentsGroup(unittest.TestCase):
'icon': 'mdi:work', 'icon': 'mdi:work',
'view': True, 'view': True,
}}}): }}}):
with patch('homeassistant.config.find_config_file',
return_value=''):
group.reload(self.hass) group.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()

View file

@ -199,6 +199,8 @@ class TestScriptComponent(unittest.TestCase):
} }
}] }]
}}}): }}}):
with patch('homeassistant.config.find_config_file',
return_value=''):
script.reload(self.hass) script.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()