Add typing to homeassistant/*.py and homeassistant/util/ (#15569)
* Add typing to homeassistant/*.py and homeassistant/util/ * Fix wrong merge * Restore iterable in OrderedSet * Fix tests
This commit is contained in:
parent
b7c336a687
commit
140a874917
27 changed files with 532 additions and 384 deletions
|
@ -20,7 +20,7 @@ from homeassistant.const import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def attempt_use_uvloop():
|
def attempt_use_uvloop() -> None:
|
||||||
"""Attempt to use uvloop."""
|
"""Attempt to use uvloop."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
@ -280,8 +280,8 @@ def setup_and_run_hass(config_dir: str,
|
||||||
# Imported here to avoid importing asyncio before monkey patch
|
# Imported here to avoid importing asyncio before monkey patch
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
|
|
||||||
def open_browser(event):
|
def open_browser(_: Any) -> None:
|
||||||
"""Open the webinterface in a browser."""
|
"""Open the web interface in a browser."""
|
||||||
if hass.config.api is not None: # type: ignore
|
if hass.config.api is not None: # type: ignore
|
||||||
import webbrowser
|
import webbrowser
|
||||||
webbrowser.open(hass.config.api.base_url) # type: ignore
|
webbrowser.open(hass.config.api.base_url) # type: ignore
|
||||||
|
|
|
@ -221,8 +221,8 @@ async def async_from_config_file(config_path: str,
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_enable_logging(hass: core.HomeAssistant,
|
def async_enable_logging(hass: core.HomeAssistant,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
log_rotate_days=None,
|
log_rotate_days: Optional[int] = None,
|
||||||
log_file=None,
|
log_file: Optional[str] = None,
|
||||||
log_no_color: bool = False) -> None:
|
log_no_color: bool = False) -> None:
|
||||||
"""Set up the logging.
|
"""Set up the logging.
|
||||||
|
|
||||||
|
@ -291,7 +291,7 @@ def async_enable_logging(hass: core.HomeAssistant,
|
||||||
|
|
||||||
async_handler = AsyncHandler(hass.loop, err_handler)
|
async_handler = AsyncHandler(hass.loop, err_handler)
|
||||||
|
|
||||||
async def async_stop_async_handler(event):
|
async def async_stop_async_handler(_: Any) -> None:
|
||||||
"""Cleanup async handler."""
|
"""Cleanup async handler."""
|
||||||
logging.getLogger('').removeHandler(async_handler) # type: ignore
|
logging.getLogger('').removeHandler(async_handler) # type: ignore
|
||||||
await async_handler.async_close(blocking=True)
|
await async_handler.async_close(blocking=True)
|
||||||
|
|
|
@ -9,7 +9,7 @@ import logging
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from typing import Optional
|
||||||
from homeassistant.auth.util import generate_secret
|
from homeassistant.auth.util import generate_secret
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
|
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
|
||||||
|
@ -241,7 +241,7 @@ class RachioIro:
|
||||||
# Only enabled zones
|
# Only enabled zones
|
||||||
return [z for z in self._zones if z[KEY_ENABLED]]
|
return [z for z in self._zones if z[KEY_ENABLED]]
|
||||||
|
|
||||||
def get_zone(self, zone_id) -> dict or None:
|
def get_zone(self, zone_id) -> Optional[dict]:
|
||||||
"""Return the zone with the given ID."""
|
"""Return the zone with the given ID."""
|
||||||
for zone in self.list_zones(include_disabled=True):
|
for zone in self.list_zones(include_disabled=True):
|
||||||
if zone[KEY_ID] == zone_id:
|
if zone[KEY_ID] == zone_id:
|
||||||
|
|
|
@ -7,8 +7,9 @@ import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from typing import Any, Tuple, Optional # noqa: F401
|
from typing import ( # noqa: F401
|
||||||
|
Any, Tuple, Optional, Dict, List, Union, Callable)
|
||||||
|
from types import ModuleType
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ from homeassistant.const import (
|
||||||
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
|
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
|
||||||
__version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB,
|
__version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB,
|
||||||
CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE)
|
CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE)
|
||||||
from homeassistant.core import callback, DOMAIN as CONF_CORE
|
from homeassistant.core import callback, DOMAIN as CONF_CORE, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.loader import get_component, get_platform
|
from homeassistant.loader import get_component, get_platform
|
||||||
from homeassistant.util.yaml import load_yaml, SECRET_YAML
|
from homeassistant.util.yaml import load_yaml, SECRET_YAML
|
||||||
|
@ -193,7 +194,7 @@ def ensure_config_exists(config_dir: str, detect_location: bool = True)\
|
||||||
return config_path
|
return config_path
|
||||||
|
|
||||||
|
|
||||||
def create_default_config(config_dir: str, detect_location=True)\
|
def create_default_config(config_dir: str, detect_location: bool = True)\
|
||||||
-> Optional[str]:
|
-> Optional[str]:
|
||||||
"""Create a default configuration file in given configuration directory.
|
"""Create a default configuration file in given configuration directory.
|
||||||
|
|
||||||
|
@ -276,7 +277,7 @@ def create_default_config(config_dir: str, detect_location=True)\
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def async_hass_config_yaml(hass):
|
async def async_hass_config_yaml(hass: HomeAssistant) -> Dict:
|
||||||
"""Load YAML from a Home Assistant configuration file.
|
"""Load YAML from a Home Assistant configuration file.
|
||||||
|
|
||||||
This function allow a component inside the asyncio loop to reload its
|
This function allow a component inside the asyncio loop to reload its
|
||||||
|
@ -284,23 +285,26 @@ async def async_hass_config_yaml(hass):
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
def _load_hass_yaml_config():
|
def _load_hass_yaml_config() -> Dict:
|
||||||
path = find_config_file(hass.config.config_dir)
|
path = find_config_file(hass.config.config_dir)
|
||||||
conf = load_yaml_config_file(path)
|
if path is None:
|
||||||
return conf
|
raise HomeAssistantError(
|
||||||
|
"Config file not found in: {}".format(hass.config.config_dir))
|
||||||
|
return load_yaml_config_file(path)
|
||||||
|
|
||||||
conf = await hass.async_add_job(_load_hass_yaml_config)
|
return await hass.async_add_executor_job(_load_hass_yaml_config)
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
def find_config_file(config_dir: str) -> Optional[str]:
|
def find_config_file(config_dir: Optional[str]) -> Optional[str]:
|
||||||
"""Look in given directory for supported configuration files."""
|
"""Look in given directory for supported configuration files."""
|
||||||
|
if config_dir is None:
|
||||||
|
return None
|
||||||
config_path = os.path.join(config_dir, YAML_CONFIG_FILE)
|
config_path = os.path.join(config_dir, YAML_CONFIG_FILE)
|
||||||
|
|
||||||
return config_path if os.path.isfile(config_path) else None
|
return config_path if os.path.isfile(config_path) else None
|
||||||
|
|
||||||
|
|
||||||
def load_yaml_config_file(config_path):
|
def load_yaml_config_file(config_path: str) -> Dict[Any, Any]:
|
||||||
"""Parse a YAML configuration file.
|
"""Parse a YAML configuration file.
|
||||||
|
|
||||||
This method needs to run in an executor.
|
This method needs to run in an executor.
|
||||||
|
@ -323,7 +327,7 @@ def load_yaml_config_file(config_path):
|
||||||
return conf_dict
|
return conf_dict
|
||||||
|
|
||||||
|
|
||||||
def process_ha_config_upgrade(hass):
|
def process_ha_config_upgrade(hass: HomeAssistant) -> None:
|
||||||
"""Upgrade configuration if necessary.
|
"""Upgrade configuration if necessary.
|
||||||
|
|
||||||
This method needs to run in an executor.
|
This method needs to run in an executor.
|
||||||
|
@ -360,7 +364,8 @@ def process_ha_config_upgrade(hass):
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_log_exception(ex, domain, config, hass):
|
def async_log_exception(ex: vol.Invalid, domain: str, config: Dict,
|
||||||
|
hass: HomeAssistant) -> None:
|
||||||
"""Log an error for configuration validation.
|
"""Log an error for configuration validation.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -371,7 +376,7 @@ def async_log_exception(ex, domain, config, hass):
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _format_config_error(ex, domain, config):
|
def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str:
|
||||||
"""Generate log exception for configuration validation.
|
"""Generate log exception for configuration validation.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -396,7 +401,8 @@ def _format_config_error(ex, domain, config):
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
async def async_process_ha_core_config(hass, config):
|
async def async_process_ha_core_config(
|
||||||
|
hass: HomeAssistant, config: Dict) -> None:
|
||||||
"""Process the [homeassistant] section from the configuration.
|
"""Process the [homeassistant] section from the configuration.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -405,12 +411,12 @@ async def async_process_ha_core_config(hass, config):
|
||||||
|
|
||||||
# Only load auth during startup.
|
# Only load auth during startup.
|
||||||
if not hasattr(hass, 'auth'):
|
if not hasattr(hass, 'auth'):
|
||||||
hass.auth = await auth.auth_manager_from_config(
|
setattr(hass, 'auth', await auth.auth_manager_from_config(
|
||||||
hass, config.get(CONF_AUTH_PROVIDERS, []))
|
hass, config.get(CONF_AUTH_PROVIDERS, [])))
|
||||||
|
|
||||||
hac = hass.config
|
hac = hass.config
|
||||||
|
|
||||||
def set_time_zone(time_zone_str):
|
def set_time_zone(time_zone_str: Optional[str]) -> None:
|
||||||
"""Help to set the time zone."""
|
"""Help to set the time zone."""
|
||||||
if time_zone_str is None:
|
if time_zone_str is None:
|
||||||
return
|
return
|
||||||
|
@ -430,11 +436,10 @@ async def async_process_ha_core_config(hass, config):
|
||||||
if key in config:
|
if key in config:
|
||||||
setattr(hac, attr, config[key])
|
setattr(hac, attr, config[key])
|
||||||
|
|
||||||
if CONF_TIME_ZONE in config:
|
set_time_zone(config.get(CONF_TIME_ZONE))
|
||||||
set_time_zone(config.get(CONF_TIME_ZONE))
|
|
||||||
|
|
||||||
# Init whitelist external dir
|
# Init whitelist external dir
|
||||||
hac.whitelist_external_dirs = set((hass.config.path('www'),))
|
hac.whitelist_external_dirs = {hass.config.path('www')}
|
||||||
if CONF_WHITELIST_EXTERNAL_DIRS in config:
|
if CONF_WHITELIST_EXTERNAL_DIRS in config:
|
||||||
hac.whitelist_external_dirs.update(
|
hac.whitelist_external_dirs.update(
|
||||||
set(config[CONF_WHITELIST_EXTERNAL_DIRS]))
|
set(config[CONF_WHITELIST_EXTERNAL_DIRS]))
|
||||||
|
@ -484,12 +489,12 @@ async def async_process_ha_core_config(hass, config):
|
||||||
hac.time_zone, hac.elevation):
|
hac.time_zone, hac.elevation):
|
||||||
return
|
return
|
||||||
|
|
||||||
discovered = []
|
discovered = [] # type: List[Tuple[str, Any]]
|
||||||
|
|
||||||
# If we miss some of the needed values, auto detect them
|
# If we miss some of the needed values, auto detect them
|
||||||
if None in (hac.latitude, hac.longitude, hac.units,
|
if None in (hac.latitude, hac.longitude, hac.units,
|
||||||
hac.time_zone):
|
hac.time_zone):
|
||||||
info = await hass.async_add_job(
|
info = await hass.async_add_executor_job(
|
||||||
loc_util.detect_location_info)
|
loc_util.detect_location_info)
|
||||||
|
|
||||||
if info is None:
|
if info is None:
|
||||||
|
@ -515,7 +520,7 @@ async def async_process_ha_core_config(hass, config):
|
||||||
|
|
||||||
if hac.elevation is None and hac.latitude is not None and \
|
if hac.elevation is None and hac.latitude is not None and \
|
||||||
hac.longitude is not None:
|
hac.longitude is not None:
|
||||||
elevation = await hass.async_add_job(
|
elevation = await hass.async_add_executor_job(
|
||||||
loc_util.elevation, hac.latitude, hac.longitude)
|
loc_util.elevation, hac.latitude, hac.longitude)
|
||||||
hac.elevation = elevation
|
hac.elevation = elevation
|
||||||
discovered.append(('elevation', elevation))
|
discovered.append(('elevation', elevation))
|
||||||
|
@ -526,7 +531,8 @@ async def async_process_ha_core_config(hass, config):
|
||||||
", ".join('{}: {}'.format(key, val) for key, val in discovered))
|
", ".join('{}: {}'.format(key, val) for key, val in discovered))
|
||||||
|
|
||||||
|
|
||||||
def _log_pkg_error(package, component, config, message):
|
def _log_pkg_error(
|
||||||
|
package: str, component: str, config: Dict, message: str) -> None:
|
||||||
"""Log an error while merging packages."""
|
"""Log an error while merging packages."""
|
||||||
message = "Package {} setup failed. Component {} {}".format(
|
message = "Package {} setup failed. Component {} {}".format(
|
||||||
package, component, message)
|
package, component, message)
|
||||||
|
@ -539,12 +545,13 @@ def _log_pkg_error(package, component, config, message):
|
||||||
_LOGGER.error(message)
|
_LOGGER.error(message)
|
||||||
|
|
||||||
|
|
||||||
def _identify_config_schema(module):
|
def _identify_config_schema(module: ModuleType) -> \
|
||||||
|
Tuple[Optional[str], Optional[Dict]]:
|
||||||
"""Extract the schema and identify list or dict based."""
|
"""Extract the schema and identify list or dict based."""
|
||||||
try:
|
try:
|
||||||
schema = module.CONFIG_SCHEMA.schema[module.DOMAIN]
|
schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] # type: ignore
|
||||||
except (AttributeError, KeyError):
|
except (AttributeError, KeyError):
|
||||||
return (None, None)
|
return None, None
|
||||||
t_schema = str(schema)
|
t_schema = str(schema)
|
||||||
if t_schema.startswith('{'):
|
if t_schema.startswith('{'):
|
||||||
return ('dict', schema)
|
return ('dict', schema)
|
||||||
|
@ -553,9 +560,10 @@ def _identify_config_schema(module):
|
||||||
return '', schema
|
return '', schema
|
||||||
|
|
||||||
|
|
||||||
def _recursive_merge(conf, package):
|
def _recursive_merge(
|
||||||
|
conf: Dict[str, Any], package: Dict[str, Any]) -> Union[bool, str]:
|
||||||
"""Merge package into conf, recursively."""
|
"""Merge package into conf, recursively."""
|
||||||
error = False
|
error = False # type: Union[bool, str]
|
||||||
for key, pack_conf in package.items():
|
for key, pack_conf in package.items():
|
||||||
if isinstance(pack_conf, dict):
|
if isinstance(pack_conf, dict):
|
||||||
if not pack_conf:
|
if not pack_conf:
|
||||||
|
@ -576,8 +584,8 @@ def _recursive_merge(conf, package):
|
||||||
return error
|
return error
|
||||||
|
|
||||||
|
|
||||||
def merge_packages_config(hass, config, packages,
|
def merge_packages_config(hass: HomeAssistant, config: Dict, packages: Dict,
|
||||||
_log_pkg_error=_log_pkg_error):
|
_log_pkg_error: Callable = _log_pkg_error) -> Dict:
|
||||||
"""Merge packages into the top-level configuration. Mutate config."""
|
"""Merge packages into the top-level configuration. Mutate config."""
|
||||||
# pylint: disable=too-many-nested-blocks
|
# pylint: disable=too-many-nested-blocks
|
||||||
PACKAGES_CONFIG_SCHEMA(packages)
|
PACKAGES_CONFIG_SCHEMA(packages)
|
||||||
|
@ -641,7 +649,8 @@ def merge_packages_config(hass, config, packages,
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_process_component_config(hass, config, domain):
|
def async_process_component_config(
|
||||||
|
hass: HomeAssistant, config: Dict, domain: str) -> Optional[Dict]:
|
||||||
"""Check component configuration and return processed configuration.
|
"""Check component configuration and return processed configuration.
|
||||||
|
|
||||||
Returns None on error.
|
Returns None on error.
|
||||||
|
@ -703,14 +712,14 @@ def async_process_component_config(hass, config, domain):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
async def async_check_ha_config_file(hass):
|
async def async_check_ha_config_file(hass: HomeAssistant) -> Optional[str]:
|
||||||
"""Check if Home Assistant configuration file is valid.
|
"""Check if Home Assistant configuration file is valid.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
from homeassistant.scripts.check_config import check_ha_config_file
|
from homeassistant.scripts.check_config import check_ha_config_file
|
||||||
|
|
||||||
res = await hass.async_add_job(
|
res = await hass.async_add_executor_job(
|
||||||
check_ha_config_file, hass)
|
check_ha_config_file, hass)
|
||||||
|
|
||||||
if not res.errors:
|
if not res.errors:
|
||||||
|
@ -719,7 +728,9 @@ async def async_check_ha_config_file(hass):
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_notify_setup_error(hass, component, display_link=False):
|
def async_notify_setup_error(
|
||||||
|
hass: HomeAssistant, component: str,
|
||||||
|
display_link: bool = False) -> None:
|
||||||
"""Print a persistent notification.
|
"""Print a persistent notification.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
|
|
@ -113,10 +113,10 @@ the flow from the config panel.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Set # noqa pylint: disable=unused-import
|
from typing import Set, Optional # noqa pylint: disable=unused-import
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.setup import async_setup_component, async_process_deps_reqs
|
from homeassistant.setup import async_setup_component, async_process_deps_reqs
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
|
@ -164,8 +164,9 @@ class ConfigEntry:
|
||||||
__slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source',
|
__slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source',
|
||||||
'state')
|
'state')
|
||||||
|
|
||||||
def __init__(self, version, domain, title, data, source, entry_id=None,
|
def __init__(self, version: str, domain: str, title: str, data: dict,
|
||||||
state=ENTRY_STATE_NOT_LOADED):
|
source: str, entry_id: Optional[str] = None,
|
||||||
|
state: str = ENTRY_STATE_NOT_LOADED) -> None:
|
||||||
"""Initialize a config entry."""
|
"""Initialize a config entry."""
|
||||||
# Unique id of the config entry
|
# Unique id of the config entry
|
||||||
self.entry_id = entry_id or uuid.uuid4().hex
|
self.entry_id = entry_id or uuid.uuid4().hex
|
||||||
|
@ -188,7 +189,8 @@ class ConfigEntry:
|
||||||
# State of the entry (LOADED, NOT_LOADED)
|
# State of the entry (LOADED, NOT_LOADED)
|
||||||
self.state = state
|
self.state = state
|
||||||
|
|
||||||
async def async_setup(self, hass, *, component=None):
|
async def async_setup(
|
||||||
|
self, hass: HomeAssistant, *, component=None) -> None:
|
||||||
"""Set up an entry."""
|
"""Set up an entry."""
|
||||||
if component is None:
|
if component is None:
|
||||||
component = getattr(hass.components, self.domain)
|
component = getattr(hass.components, self.domain)
|
||||||
|
|
|
@ -4,9 +4,9 @@ Core components of Home Assistant.
|
||||||
Home Assistant is a Home Automation framework for observing the state
|
Home Assistant is a Home Automation framework for observing the state
|
||||||
of entities and react to changes.
|
of entities and react to changes.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=unused-import
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import datetime
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -17,9 +17,10 @@ import threading
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
|
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
# pylint: disable=unused-import
|
||||||
from typing import ( # NOQA
|
from typing import ( # NOQA
|
||||||
Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set,
|
Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set,
|
||||||
TYPE_CHECKING)
|
TYPE_CHECKING, Awaitable, Iterator)
|
||||||
|
|
||||||
from async_timeout import timeout
|
from async_timeout import timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -44,11 +45,13 @@ from homeassistant.util import location
|
||||||
from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA
|
from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA
|
||||||
|
|
||||||
# Typing imports that create a circular dependency
|
# Typing imports that create a circular dependency
|
||||||
# pylint: disable=using-constant-test,unused-import
|
# pylint: disable=using-constant-test
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from homeassistant.config_entries import ConfigEntries # noqa
|
from homeassistant.config_entries import ConfigEntries # noqa
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
|
||||||
|
CALLBACK_TYPE = Callable[[], None]
|
||||||
|
|
||||||
DOMAIN = 'homeassistant'
|
DOMAIN = 'homeassistant'
|
||||||
|
|
||||||
|
@ -79,7 +82,7 @@ def valid_state(state: str) -> bool:
|
||||||
return len(state) < 256
|
return len(state) < 256
|
||||||
|
|
||||||
|
|
||||||
def callback(func: Callable[..., T]) -> Callable[..., T]:
|
def callback(func: CALLABLE_T) -> CALLABLE_T:
|
||||||
"""Annotation to mark method as safe to call from within the event loop."""
|
"""Annotation to mark method as safe to call from within the event loop."""
|
||||||
setattr(func, '_hass_callback', True)
|
setattr(func, '_hass_callback', True)
|
||||||
return func
|
return func
|
||||||
|
@ -91,7 +94,7 @@ def is_callback(func: Callable[..., Any]) -> bool:
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_loop_exception_handler(loop, context):
|
def async_loop_exception_handler(_: Any, context: Dict) -> None:
|
||||||
"""Handle all exception inside the core loop."""
|
"""Handle all exception inside the core loop."""
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
exception = context.get('exception')
|
exception = context.get('exception')
|
||||||
|
@ -119,7 +122,9 @@ class CoreState(enum.Enum):
|
||||||
class HomeAssistant:
|
class HomeAssistant:
|
||||||
"""Root object of the Home Assistant home automation."""
|
"""Root object of the Home Assistant home automation."""
|
||||||
|
|
||||||
def __init__(self, loop=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None:
|
||||||
"""Initialize new Home Assistant object."""
|
"""Initialize new Home Assistant object."""
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
self.loop = loop or asyncio.ProactorEventLoop()
|
self.loop = loop or asyncio.ProactorEventLoop()
|
||||||
|
@ -170,7 +175,7 @@ class HomeAssistant:
|
||||||
self.loop.close()
|
self.loop.close()
|
||||||
return self.exit_code
|
return self.exit_code
|
||||||
|
|
||||||
async def async_start(self):
|
async def async_start(self) -> None:
|
||||||
"""Finalize startup from inside the event loop.
|
"""Finalize startup from inside the event loop.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -178,8 +183,7 @@ class HomeAssistant:
|
||||||
_LOGGER.info("Starting Home Assistant")
|
_LOGGER.info("Starting Home Assistant")
|
||||||
self.state = CoreState.starting
|
self.state = CoreState.starting
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
setattr(self.loop, '_thread_ident', threading.get_ident())
|
||||||
self.loop._thread_ident = threading.get_ident()
|
|
||||||
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -230,7 +234,8 @@ class HomeAssistant:
|
||||||
elif asyncio.iscoroutinefunction(target):
|
elif asyncio.iscoroutinefunction(target):
|
||||||
task = self.loop.create_task(target(*args))
|
task = self.loop.create_task(target(*args))
|
||||||
else:
|
else:
|
||||||
task = self.loop.run_in_executor(None, target, *args)
|
task = self.loop.run_in_executor( # type: ignore
|
||||||
|
None, target, *args)
|
||||||
|
|
||||||
# If a task is scheduled
|
# If a task is scheduled
|
||||||
if self._track_task and task is not None:
|
if self._track_task and task is not None:
|
||||||
|
@ -256,11 +261,11 @@ class HomeAssistant:
|
||||||
@callback
|
@callback
|
||||||
def async_add_executor_job(
|
def async_add_executor_job(
|
||||||
self,
|
self,
|
||||||
target: Callable[..., Any],
|
target: Callable[..., T],
|
||||||
*args: Any) -> asyncio.Future:
|
*args: Any) -> Awaitable[T]:
|
||||||
"""Add an executor job from within the event loop."""
|
"""Add an executor job from within the event loop."""
|
||||||
task = self.loop.run_in_executor( # type: ignore
|
task = self.loop.run_in_executor(
|
||||||
None, target, *args) # type: asyncio.Future
|
None, target, *args)
|
||||||
|
|
||||||
# If a task is scheduled
|
# If a task is scheduled
|
||||||
if self._track_task:
|
if self._track_task:
|
||||||
|
@ -269,12 +274,12 @@ class HomeAssistant:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_track_tasks(self):
|
def async_track_tasks(self) -> None:
|
||||||
"""Track tasks so you can wait for all tasks to be done."""
|
"""Track tasks so you can wait for all tasks to be done."""
|
||||||
self._track_task = True
|
self._track_task = True
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_stop_track_tasks(self):
|
def async_stop_track_tasks(self) -> None:
|
||||||
"""Stop track tasks so you can't wait for all tasks to be done."""
|
"""Stop track tasks so you can't wait for all tasks to be done."""
|
||||||
self._track_task = False
|
self._track_task = False
|
||||||
|
|
||||||
|
@ -297,7 +302,7 @@ class HomeAssistant:
|
||||||
run_coroutine_threadsafe(
|
run_coroutine_threadsafe(
|
||||||
self.async_block_till_done(), loop=self.loop).result()
|
self.async_block_till_done(), loop=self.loop).result()
|
||||||
|
|
||||||
async def async_block_till_done(self):
|
async def async_block_till_done(self) -> None:
|
||||||
"""Block till all pending work is done."""
|
"""Block till all pending work is done."""
|
||||||
# To flush out any call_soon_threadsafe
|
# To flush out any call_soon_threadsafe
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0, loop=self.loop)
|
||||||
|
@ -342,9 +347,9 @@ class EventOrigin(enum.Enum):
|
||||||
local = 'LOCAL'
|
local = 'LOCAL'
|
||||||
remote = 'REMOTE'
|
remote = 'REMOTE'
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
"""Return the event."""
|
"""Return the event."""
|
||||||
return self.value
|
return self.value # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Event:
|
class Event:
|
||||||
|
@ -352,15 +357,16 @@ class Event:
|
||||||
|
|
||||||
__slots__ = ['event_type', 'data', 'origin', 'time_fired']
|
__slots__ = ['event_type', 'data', 'origin', 'time_fired']
|
||||||
|
|
||||||
def __init__(self, event_type, data=None, origin=EventOrigin.local,
|
def __init__(self, event_type: str, data: Optional[Dict] = None,
|
||||||
time_fired=None):
|
origin: EventOrigin = EventOrigin.local,
|
||||||
|
time_fired: Optional[int] = None) -> None:
|
||||||
"""Initialize a new event."""
|
"""Initialize a new event."""
|
||||||
self.event_type = event_type
|
self.event_type = event_type
|
||||||
self.data = data or {}
|
self.data = data or {}
|
||||||
self.origin = origin
|
self.origin = origin
|
||||||
self.time_fired = time_fired or dt_util.utcnow()
|
self.time_fired = time_fired or dt_util.utcnow()
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> Dict:
|
||||||
"""Create a dict representation of this Event.
|
"""Create a dict representation of this Event.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -372,7 +378,7 @@ class Event:
|
||||||
'time_fired': self.time_fired,
|
'time_fired': self.time_fired,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the representation."""
|
"""Return the representation."""
|
||||||
# pylint: disable=maybe-no-member
|
# pylint: disable=maybe-no-member
|
||||||
if self.data:
|
if self.data:
|
||||||
|
@ -383,9 +389,9 @@ class Event:
|
||||||
return "<Event {}[{}]>".format(self.event_type,
|
return "<Event {}[{}]>".format(self.event_type,
|
||||||
str(self.origin)[0])
|
str(self.origin)[0])
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison."""
|
"""Return the comparison."""
|
||||||
return (self.__class__ == other.__class__ and
|
return (self.__class__ == other.__class__ and # type: ignore
|
||||||
self.event_type == other.event_type and
|
self.event_type == other.event_type and
|
||||||
self.data == other.data and
|
self.data == other.data and
|
||||||
self.origin == other.origin and
|
self.origin == other.origin and
|
||||||
|
@ -401,7 +407,7 @@ class EventBus:
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_listeners(self):
|
def async_listeners(self) -> Dict[str, int]:
|
||||||
"""Return dictionary with events and the number of listeners.
|
"""Return dictionary with events and the number of listeners.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -410,20 +416,21 @@ class EventBus:
|
||||||
for key in self._listeners}
|
for key in self._listeners}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def listeners(self):
|
def listeners(self) -> Dict[str, int]:
|
||||||
"""Return dictionary with events and the number of listeners."""
|
"""Return dictionary with events and the number of listeners."""
|
||||||
return run_callback_threadsafe(
|
return run_callback_threadsafe( # type: ignore
|
||||||
self._hass.loop, self.async_listeners
|
self._hass.loop, self.async_listeners
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
|
def fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||||
|
origin: EventOrigin = EventOrigin.local) -> None:
|
||||||
"""Fire an event."""
|
"""Fire an event."""
|
||||||
self._hass.loop.call_soon_threadsafe(
|
self._hass.loop.call_soon_threadsafe(
|
||||||
self.async_fire, event_type, event_data, origin)
|
self.async_fire, event_type, event_data, origin)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_fire(self, event_type: str, event_data=None,
|
def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||||
origin=EventOrigin.local):
|
origin: EventOrigin = EventOrigin.local) -> None:
|
||||||
"""Fire an event.
|
"""Fire an event.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -447,7 +454,8 @@ class EventBus:
|
||||||
for func in listeners:
|
for func in listeners:
|
||||||
self._hass.async_add_job(func, event)
|
self._hass.async_add_job(func, event)
|
||||||
|
|
||||||
def listen(self, event_type, listener):
|
def listen(
|
||||||
|
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||||
"""Listen for all events or events of a specific type.
|
"""Listen for all events or events of a specific type.
|
||||||
|
|
||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
|
@ -456,7 +464,7 @@ class EventBus:
|
||||||
async_remove_listener = run_callback_threadsafe(
|
async_remove_listener = run_callback_threadsafe(
|
||||||
self._hass.loop, self.async_listen, event_type, listener).result()
|
self._hass.loop, self.async_listen, event_type, listener).result()
|
||||||
|
|
||||||
def remove_listener():
|
def remove_listener() -> None:
|
||||||
"""Remove the listener."""
|
"""Remove the listener."""
|
||||||
run_callback_threadsafe(
|
run_callback_threadsafe(
|
||||||
self._hass.loop, async_remove_listener).result()
|
self._hass.loop, async_remove_listener).result()
|
||||||
|
@ -464,7 +472,8 @@ class EventBus:
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_listen(self, event_type, listener):
|
def async_listen(
|
||||||
|
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||||
"""Listen for all events or events of a specific type.
|
"""Listen for all events or events of a specific type.
|
||||||
|
|
||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
|
@ -477,13 +486,14 @@ class EventBus:
|
||||||
else:
|
else:
|
||||||
self._listeners[event_type] = [listener]
|
self._listeners[event_type] = [listener]
|
||||||
|
|
||||||
def remove_listener():
|
def remove_listener() -> None:
|
||||||
"""Remove the listener."""
|
"""Remove the listener."""
|
||||||
self._async_remove_listener(event_type, listener)
|
self._async_remove_listener(event_type, listener)
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
def listen_once(self, event_type, listener):
|
def listen_once(
|
||||||
|
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||||
"""Listen once for event of a specific type.
|
"""Listen once for event of a specific type.
|
||||||
|
|
||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
|
@ -495,7 +505,7 @@ class EventBus:
|
||||||
self._hass.loop, self.async_listen_once, event_type, listener,
|
self._hass.loop, self.async_listen_once, event_type, listener,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
def remove_listener():
|
def remove_listener() -> None:
|
||||||
"""Remove the listener."""
|
"""Remove the listener."""
|
||||||
run_callback_threadsafe(
|
run_callback_threadsafe(
|
||||||
self._hass.loop, async_remove_listener).result()
|
self._hass.loop, async_remove_listener).result()
|
||||||
|
@ -503,7 +513,8 @@ class EventBus:
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_listen_once(self, event_type, listener):
|
def async_listen_once(
|
||||||
|
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||||
"""Listen once for event of a specific type.
|
"""Listen once for event of a specific type.
|
||||||
|
|
||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
|
@ -514,8 +525,8 @@ class EventBus:
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
@callback
|
@callback
|
||||||
def onetime_listener(event):
|
def onetime_listener(event: Event) -> None:
|
||||||
"""Remove listener from eventbus and then fire listener."""
|
"""Remove listener from event bus and then fire listener."""
|
||||||
if hasattr(onetime_listener, 'run'):
|
if hasattr(onetime_listener, 'run'):
|
||||||
return
|
return
|
||||||
# Set variable so that we will never run twice.
|
# Set variable so that we will never run twice.
|
||||||
|
@ -530,7 +541,8 @@ class EventBus:
|
||||||
return self.async_listen(event_type, onetime_listener)
|
return self.async_listen(event_type, onetime_listener)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_remove_listener(self, event_type, listener):
|
def _async_remove_listener(
|
||||||
|
self, event_type: str, listener: Callable) -> None:
|
||||||
"""Remove a listener of a specific event_type.
|
"""Remove a listener of a specific event_type.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -560,8 +572,10 @@ class State:
|
||||||
__slots__ = ['entity_id', 'state', 'attributes',
|
__slots__ = ['entity_id', 'state', 'attributes',
|
||||||
'last_changed', 'last_updated']
|
'last_changed', 'last_updated']
|
||||||
|
|
||||||
def __init__(self, entity_id, state, attributes=None, last_changed=None,
|
def __init__(self, entity_id: str, state: Any,
|
||||||
last_updated=None):
|
attributes: Optional[Dict] = None,
|
||||||
|
last_changed: Optional[datetime.datetime] = None,
|
||||||
|
last_updated: Optional[datetime.datetime] = None) -> None:
|
||||||
"""Initialize a new state."""
|
"""Initialize a new state."""
|
||||||
state = str(state)
|
state = str(state)
|
||||||
|
|
||||||
|
@ -582,23 +596,23 @@ class State:
|
||||||
self.last_changed = last_changed or self.last_updated
|
self.last_changed = last_changed or self.last_updated
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def domain(self):
|
def domain(self) -> str:
|
||||||
"""Domain of this state."""
|
"""Domain of this state."""
|
||||||
return split_entity_id(self.entity_id)[0]
|
return split_entity_id(self.entity_id)[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def object_id(self):
|
def object_id(self) -> str:
|
||||||
"""Object id of this state."""
|
"""Object id of this state."""
|
||||||
return split_entity_id(self.entity_id)[1]
|
return split_entity_id(self.entity_id)[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
"""Name of this state."""
|
"""Name of this state."""
|
||||||
return (
|
return (
|
||||||
self.attributes.get(ATTR_FRIENDLY_NAME) or
|
self.attributes.get(ATTR_FRIENDLY_NAME) or
|
||||||
self.object_id.replace('_', ' '))
|
self.object_id.replace('_', ' '))
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> Dict:
|
||||||
"""Return a dict representation of the State.
|
"""Return a dict representation of the State.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -613,7 +627,7 @@ class State:
|
||||||
'last_updated': self.last_updated}
|
'last_updated': self.last_updated}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, json_dict):
|
def from_dict(cls, json_dict: Dict) -> Any:
|
||||||
"""Initialize a state from a dict.
|
"""Initialize a state from a dict.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -637,14 +651,14 @@ class State:
|
||||||
return cls(json_dict['entity_id'], json_dict['state'],
|
return cls(json_dict['entity_id'], json_dict['state'],
|
||||||
json_dict.get('attributes'), last_changed, last_updated)
|
json_dict.get('attributes'), last_changed, last_updated)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison of the state."""
|
"""Return the comparison of the state."""
|
||||||
return (self.__class__ == other.__class__ and
|
return (self.__class__ == other.__class__ and # type: ignore
|
||||||
self.entity_id == other.entity_id and
|
self.entity_id == other.entity_id and
|
||||||
self.state == other.state and
|
self.state == other.state and
|
||||||
self.attributes == other.attributes)
|
self.attributes == other.attributes)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the states."""
|
"""Return the representation of the states."""
|
||||||
attr = "; {}".format(util.repr_helper(self.attributes)) \
|
attr = "; {}".format(util.repr_helper(self.attributes)) \
|
||||||
if self.attributes else ""
|
if self.attributes else ""
|
||||||
|
@ -657,21 +671,23 @@ class State:
|
||||||
class StateMachine:
|
class StateMachine:
|
||||||
"""Helper class that tracks the state of different entities."""
|
"""Helper class that tracks the state of different entities."""
|
||||||
|
|
||||||
def __init__(self, bus, loop):
|
def __init__(self, bus: EventBus,
|
||||||
|
loop: asyncio.events.AbstractEventLoop) -> None:
|
||||||
"""Initialize state machine."""
|
"""Initialize state machine."""
|
||||||
self._states = {} # type: Dict[str, State]
|
self._states = {} # type: Dict[str, State]
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
|
|
||||||
def entity_ids(self, domain_filter=None):
|
def entity_ids(self, domain_filter: Optional[str] = None)-> List[str]:
|
||||||
"""List of entity ids that are being tracked."""
|
"""List of entity ids that are being tracked."""
|
||||||
future = run_callback_threadsafe(
|
future = run_callback_threadsafe(
|
||||||
self._loop, self.async_entity_ids, domain_filter
|
self._loop, self.async_entity_ids, domain_filter
|
||||||
)
|
)
|
||||||
return future.result()
|
return future.result() # type: ignore
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_entity_ids(self, domain_filter=None):
|
def async_entity_ids(
|
||||||
|
self, domain_filter: Optional[str] = None) -> List[str]:
|
||||||
"""List of entity ids that are being tracked.
|
"""List of entity ids that are being tracked.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -684,26 +700,27 @@ class StateMachine:
|
||||||
return [state.entity_id for state in self._states.values()
|
return [state.entity_id for state in self._states.values()
|
||||||
if state.domain == domain_filter]
|
if state.domain == domain_filter]
|
||||||
|
|
||||||
def all(self):
|
def all(self)-> List[State]:
|
||||||
"""Create a list of all states."""
|
"""Create a list of all states."""
|
||||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
return run_callback_threadsafe( # type: ignore
|
||||||
|
self._loop, self.async_all).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_all(self):
|
def async_all(self)-> List[State]:
|
||||||
"""Create a list of all states.
|
"""Create a list of all states.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
return list(self._states.values())
|
return list(self._states.values())
|
||||||
|
|
||||||
def get(self, entity_id):
|
def get(self, entity_id: str) -> Optional[State]:
|
||||||
"""Retrieve state of entity_id or None if not found.
|
"""Retrieve state of entity_id or None if not found.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
return self._states.get(entity_id.lower())
|
return self._states.get(entity_id.lower())
|
||||||
|
|
||||||
def is_state(self, entity_id, state):
|
def is_state(self, entity_id: str, state: State) -> bool:
|
||||||
"""Test if entity exists and is specified state.
|
"""Test if entity exists and is specified state.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -711,16 +728,16 @@ class StateMachine:
|
||||||
state_obj = self.get(entity_id)
|
state_obj = self.get(entity_id)
|
||||||
return state_obj is not None and state_obj.state == state
|
return state_obj is not None and state_obj.state == state
|
||||||
|
|
||||||
def remove(self, entity_id):
|
def remove(self, entity_id: str) -> bool:
|
||||||
"""Remove the state of an entity.
|
"""Remove the state of an entity.
|
||||||
|
|
||||||
Returns boolean to indicate if an entity was removed.
|
Returns boolean to indicate if an entity was removed.
|
||||||
"""
|
"""
|
||||||
return run_callback_threadsafe(
|
return run_callback_threadsafe( # type: ignore
|
||||||
self._loop, self.async_remove, entity_id).result()
|
self._loop, self.async_remove, entity_id).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove(self, entity_id):
|
def async_remove(self, entity_id: str) -> bool:
|
||||||
"""Remove the state of an entity.
|
"""Remove the state of an entity.
|
||||||
|
|
||||||
Returns boolean to indicate if an entity was removed.
|
Returns boolean to indicate if an entity was removed.
|
||||||
|
@ -740,7 +757,9 @@ class StateMachine:
|
||||||
})
|
})
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def set(self, entity_id, new_state, attributes=None, force_update=False):
|
def set(self, entity_id: str, new_state: Any,
|
||||||
|
attributes: Optional[Dict] = None,
|
||||||
|
force_update: bool = False) -> None:
|
||||||
"""Set the state of an entity, add entity if it does not exist.
|
"""Set the state of an entity, add entity if it does not exist.
|
||||||
|
|
||||||
Attributes is an optional dict to specify attributes of this state.
|
Attributes is an optional dict to specify attributes of this state.
|
||||||
|
@ -754,8 +773,9 @@ class StateMachine:
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set(self, entity_id, new_state, attributes=None,
|
def async_set(self, entity_id: str, new_state: Any,
|
||||||
force_update=False):
|
attributes: Optional[Dict] = None,
|
||||||
|
force_update: bool = False) -> None:
|
||||||
"""Set the state of an entity, add entity if it does not exist.
|
"""Set the state of an entity, add entity if it does not exist.
|
||||||
|
|
||||||
Attributes is an optional dict to specify attributes of this state.
|
Attributes is an optional dict to specify attributes of this state.
|
||||||
|
@ -769,15 +789,19 @@ class StateMachine:
|
||||||
new_state = str(new_state)
|
new_state = str(new_state)
|
||||||
attributes = attributes or {}
|
attributes = attributes or {}
|
||||||
old_state = self._states.get(entity_id)
|
old_state = self._states.get(entity_id)
|
||||||
is_existing = old_state is not None
|
if old_state is None:
|
||||||
same_state = (is_existing and old_state.state == new_state and
|
same_state = False
|
||||||
not force_update)
|
same_attr = False
|
||||||
same_attr = is_existing and old_state.attributes == attributes
|
last_changed = None
|
||||||
|
else:
|
||||||
|
same_state = (old_state.state == new_state and
|
||||||
|
not force_update)
|
||||||
|
same_attr = old_state.attributes == attributes
|
||||||
|
last_changed = old_state.last_changed if same_state else None
|
||||||
|
|
||||||
if same_state and same_attr:
|
if same_state and same_attr:
|
||||||
return
|
return
|
||||||
|
|
||||||
last_changed = old_state.last_changed if same_state else None
|
|
||||||
state = State(entity_id, new_state, attributes, last_changed)
|
state = State(entity_id, new_state, attributes, last_changed)
|
||||||
self._states[entity_id] = state
|
self._states[entity_id] = state
|
||||||
self._bus.async_fire(EVENT_STATE_CHANGED, {
|
self._bus.async_fire(EVENT_STATE_CHANGED, {
|
||||||
|
@ -792,7 +816,7 @@ class Service:
|
||||||
|
|
||||||
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
|
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
|
||||||
|
|
||||||
def __init__(self, func, schema):
|
def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None:
|
||||||
"""Initialize a service."""
|
"""Initialize a service."""
|
||||||
self.func = func
|
self.func = func
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
|
@ -805,14 +829,15 @@ class ServiceCall:
|
||||||
|
|
||||||
__slots__ = ['domain', 'service', 'data', 'call_id']
|
__slots__ = ['domain', 'service', 'data', 'call_id']
|
||||||
|
|
||||||
def __init__(self, domain, service, data=None, call_id=None):
|
def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
|
||||||
|
call_id: Optional[str] = None) -> None:
|
||||||
"""Initialize a service call."""
|
"""Initialize a service call."""
|
||||||
self.domain = domain.lower()
|
self.domain = domain.lower()
|
||||||
self.service = service.lower()
|
self.service = service.lower()
|
||||||
self.data = MappingProxyType(data or {})
|
self.data = MappingProxyType(data or {})
|
||||||
self.call_id = call_id
|
self.call_id = call_id
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the service."""
|
"""Return the representation of the service."""
|
||||||
if self.data:
|
if self.data:
|
||||||
return "<ServiceCall {}.{}: {}>".format(
|
return "<ServiceCall {}.{}: {}>".format(
|
||||||
|
@ -824,13 +849,13 @@ class ServiceCall:
|
||||||
class ServiceRegistry:
|
class ServiceRegistry:
|
||||||
"""Offer the services over the eventbus."""
|
"""Offer the services over the eventbus."""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize a service registry."""
|
"""Initialize a service registry."""
|
||||||
self._services = {} # type: Dict[str, Dict[str, Service]]
|
self._services = {} # type: Dict[str, Dict[str, Service]]
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._async_unsub_call_event = None
|
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
|
||||||
|
|
||||||
def _gen_unique_id():
|
def _gen_unique_id() -> Iterator[str]:
|
||||||
cur_id = 1
|
cur_id = 1
|
||||||
while True:
|
while True:
|
||||||
yield '{}-{}'.format(id(self), cur_id)
|
yield '{}-{}'.format(id(self), cur_id)
|
||||||
|
@ -840,14 +865,14 @@ class ServiceRegistry:
|
||||||
self._generate_unique_id = lambda: next(gen)
|
self._generate_unique_id = lambda: next(gen)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def services(self):
|
def services(self) -> Dict[str, Dict[str, Service]]:
|
||||||
"""Return dictionary with per domain a list of available services."""
|
"""Return dictionary with per domain a list of available services."""
|
||||||
return run_callback_threadsafe(
|
return run_callback_threadsafe( # type: ignore
|
||||||
self._hass.loop, self.async_services,
|
self._hass.loop, self.async_services,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_services(self):
|
def async_services(self) -> Dict[str, Dict[str, Service]]:
|
||||||
"""Return dictionary with per domain a list of available services.
|
"""Return dictionary with per domain a list of available services.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -855,14 +880,15 @@ class ServiceRegistry:
|
||||||
return {domain: self._services[domain].copy()
|
return {domain: self._services[domain].copy()
|
||||||
for domain in self._services}
|
for domain in self._services}
|
||||||
|
|
||||||
def has_service(self, domain, service):
|
def has_service(self, domain: str, service: str) -> bool:
|
||||||
"""Test if specified service exists.
|
"""Test if specified service exists.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
return service.lower() in self._services.get(domain.lower(), [])
|
return service.lower() in self._services.get(domain.lower(), [])
|
||||||
|
|
||||||
def register(self, domain, service, service_func, schema=None):
|
def register(self, domain: str, service: str, service_func: Callable,
|
||||||
|
schema: Optional[vol.Schema] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Register a service.
|
Register a service.
|
||||||
|
|
||||||
|
@ -874,7 +900,8 @@ class ServiceRegistry:
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register(self, domain, service, service_func, schema=None):
|
def async_register(self, domain: str, service: str, service_func: Callable,
|
||||||
|
schema: Optional[vol.Schema] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Register a service.
|
Register a service.
|
||||||
|
|
||||||
|
@ -900,13 +927,13 @@ class ServiceRegistry:
|
||||||
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove(self, domain, service):
|
def remove(self, domain: str, service: str) -> None:
|
||||||
"""Remove a registered service from service handler."""
|
"""Remove a registered service from service handler."""
|
||||||
run_callback_threadsafe(
|
run_callback_threadsafe(
|
||||||
self._hass.loop, self.async_remove, domain, service).result()
|
self._hass.loop, self.async_remove, domain, service).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove(self, domain, service):
|
def async_remove(self, domain: str, service: str) -> None:
|
||||||
"""Remove a registered service from service handler.
|
"""Remove a registered service from service handler.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -926,7 +953,9 @@ class ServiceRegistry:
|
||||||
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
|
||||||
)
|
)
|
||||||
|
|
||||||
def call(self, domain, service, service_data=None, blocking=False):
|
def call(self, domain: str, service: str,
|
||||||
|
service_data: Optional[Dict] = None,
|
||||||
|
blocking: bool = False) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
|
|
||||||
|
@ -943,13 +972,14 @@ class ServiceRegistry:
|
||||||
Because the service is sent as an event you are not allowed to use
|
Because the service is sent as an event you are not allowed to use
|
||||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||||
"""
|
"""
|
||||||
return run_coroutine_threadsafe(
|
return run_coroutine_threadsafe( # type: ignore
|
||||||
self.async_call(domain, service, service_data, blocking),
|
self.async_call(domain, service, service_data, blocking),
|
||||||
self._hass.loop
|
self._hass.loop
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
async def async_call(self, domain, service, service_data=None,
|
async def async_call(self, domain: str, service: str,
|
||||||
blocking=False):
|
service_data: Optional[Dict] = None,
|
||||||
|
blocking: bool = False) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
|
|
||||||
|
@ -981,7 +1011,7 @@ class ServiceRegistry:
|
||||||
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
|
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def service_executed(event):
|
def service_executed(event: Event) -> None:
|
||||||
"""Handle an executed service."""
|
"""Handle an executed service."""
|
||||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||||
fut.set_result(True)
|
fut.set_result(True)
|
||||||
|
@ -989,20 +1019,22 @@ class ServiceRegistry:
|
||||||
unsub = self._hass.bus.async_listen(
|
unsub = self._hass.bus.async_listen(
|
||||||
EVENT_SERVICE_EXECUTED, service_executed)
|
EVENT_SERVICE_EXECUTED, service_executed)
|
||||||
|
|
||||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
||||||
|
|
||||||
if blocking:
|
|
||||||
done, _ = await asyncio.wait(
|
done, _ = await asyncio.wait(
|
||||||
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
|
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
|
||||||
success = bool(done)
|
success = bool(done)
|
||||||
unsub()
|
unsub()
|
||||||
return success
|
return success
|
||||||
|
|
||||||
async def _event_to_service_call(self, event):
|
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _event_to_service_call(self, event: Event) -> None:
|
||||||
"""Handle the SERVICE_CALLED events from the EventBus."""
|
"""Handle the SERVICE_CALLED events from the EventBus."""
|
||||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||||
domain = event.data.get(ATTR_DOMAIN).lower()
|
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
||||||
service = event.data.get(ATTR_SERVICE).lower()
|
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
||||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||||
|
|
||||||
if not self.has_service(domain, service):
|
if not self.has_service(domain, service):
|
||||||
|
@ -1013,7 +1045,7 @@ class ServiceRegistry:
|
||||||
|
|
||||||
service_handler = self._services[domain][service]
|
service_handler = self._services[domain][service]
|
||||||
|
|
||||||
def fire_service_executed():
|
def fire_service_executed() -> None:
|
||||||
"""Fire service executed event."""
|
"""Fire service executed event."""
|
||||||
if not call_id:
|
if not call_id:
|
||||||
return
|
return
|
||||||
|
@ -1045,12 +1077,12 @@ class ServiceRegistry:
|
||||||
await service_handler.func(service_call)
|
await service_handler.func(service_call)
|
||||||
fire_service_executed()
|
fire_service_executed()
|
||||||
else:
|
else:
|
||||||
def execute_service():
|
def execute_service() -> None:
|
||||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||||
service_handler.func(service_call)
|
service_handler.func(service_call)
|
||||||
fire_service_executed()
|
fire_service_executed()
|
||||||
|
|
||||||
await self._hass.async_add_job(execute_service)
|
await self._hass.async_add_executor_job(execute_service)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception('Error executing service %s', service_call)
|
_LOGGER.exception('Error executing service %s', service_call)
|
||||||
|
|
||||||
|
@ -1058,13 +1090,13 @@ class ServiceRegistry:
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration settings for Home Assistant."""
|
"""Configuration settings for Home Assistant."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""Initialize a new config object."""
|
"""Initialize a new config object."""
|
||||||
self.latitude = None # type: Optional[float]
|
self.latitude = None # type: Optional[float]
|
||||||
self.longitude = None # type: Optional[float]
|
self.longitude = None # type: Optional[float]
|
||||||
self.elevation = None # type: Optional[int]
|
self.elevation = None # type: Optional[int]
|
||||||
self.location_name = None # type: Optional[str]
|
self.location_name = None # type: Optional[str]
|
||||||
self.time_zone = None # type: Optional[str]
|
self.time_zone = None # type: Optional[datetime.tzinfo]
|
||||||
self.units = METRIC_SYSTEM # type: UnitSystem
|
self.units = METRIC_SYSTEM # type: UnitSystem
|
||||||
|
|
||||||
# If True, pip install is skipped for requirements on startup
|
# If True, pip install is skipped for requirements on startup
|
||||||
|
@ -1090,7 +1122,7 @@ class Config:
|
||||||
return self.units.length(
|
return self.units.length(
|
||||||
location.distance(self.latitude, self.longitude, lat, lon), 'm')
|
location.distance(self.latitude, self.longitude, lat, lon), 'm')
|
||||||
|
|
||||||
def path(self, *path):
|
def path(self, *path: str) -> str:
|
||||||
"""Generate path to the file within the configuration directory.
|
"""Generate path to the file within the configuration directory.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -1122,12 +1154,14 @@ class Config:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self) -> Dict:
|
||||||
"""Create a dictionary representation of this dict.
|
"""Create a dictionary representation of this dict.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
time_zone = self.time_zone or dt_util.UTC
|
time_zone = dt_util.UTC.zone
|
||||||
|
if self.time_zone and getattr(self.time_zone, 'zone'):
|
||||||
|
time_zone = getattr(self.time_zone, 'zone')
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'latitude': self.latitude,
|
'latitude': self.latitude,
|
||||||
|
@ -1135,7 +1169,7 @@ class Config:
|
||||||
'elevation': self.elevation,
|
'elevation': self.elevation,
|
||||||
'unit_system': self.units.as_dict(),
|
'unit_system': self.units.as_dict(),
|
||||||
'location_name': self.location_name,
|
'location_name': self.location_name,
|
||||||
'time_zone': time_zone.zone,
|
'time_zone': time_zone,
|
||||||
'components': self.components,
|
'components': self.components,
|
||||||
'config_dir': self.config_dir,
|
'config_dir': self.config_dir,
|
||||||
'whitelist_external_dirs': self.whitelist_external_dirs,
|
'whitelist_external_dirs': self.whitelist_external_dirs,
|
||||||
|
@ -1143,12 +1177,12 @@ class Config:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _async_create_timer(hass):
|
def _async_create_timer(hass: HomeAssistant) -> None:
|
||||||
"""Create a timer that will start on HOMEASSISTANT_START."""
|
"""Create a timer that will start on HOMEASSISTANT_START."""
|
||||||
handle = None
|
handle = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def fire_time_event(nxt):
|
def fire_time_event(nxt: float) -> None:
|
||||||
"""Fire next time event."""
|
"""Fire next time event."""
|
||||||
nonlocal handle
|
nonlocal handle
|
||||||
|
|
||||||
|
@ -1165,7 +1199,7 @@ def _async_create_timer(hass):
|
||||||
handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt)
|
handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def stop_timer(event):
|
def stop_timer(_: Event) -> None:
|
||||||
"""Stop the timer."""
|
"""Stop the timer."""
|
||||||
if handle is not None:
|
if handle is not None:
|
||||||
handle.cancel()
|
handle.cancel()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
"""Classes to help gather user submissions."""
|
"""Classes to help gather user submissions."""
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any # noqa pylint: disable=unused-import
|
import voluptuous as vol
|
||||||
from .core import callback
|
from typing import Dict, Any, Callable, List, Optional # noqa pylint: disable=unused-import
|
||||||
|
from .core import callback, HomeAssistant
|
||||||
from .exceptions import HomeAssistantError
|
from .exceptions import HomeAssistantError
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -35,7 +36,8 @@ class UnknownStep(FlowError):
|
||||||
class FlowManager:
|
class FlowManager:
|
||||||
"""Manage all the flows that are in progress."""
|
"""Manage all the flows that are in progress."""
|
||||||
|
|
||||||
def __init__(self, hass, async_create_flow, async_finish_flow):
|
def __init__(self, hass: HomeAssistant, async_create_flow: Callable,
|
||||||
|
async_finish_flow: Callable) -> None:
|
||||||
"""Initialize the flow manager."""
|
"""Initialize the flow manager."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._progress = {} # type: Dict[str, Any]
|
self._progress = {} # type: Dict[str, Any]
|
||||||
|
@ -43,7 +45,7 @@ class FlowManager:
|
||||||
self._async_finish_flow = async_finish_flow
|
self._async_finish_flow = async_finish_flow
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_progress(self):
|
def async_progress(self) -> List[Dict]:
|
||||||
"""Return the flows in progress."""
|
"""Return the flows in progress."""
|
||||||
return [{
|
return [{
|
||||||
'flow_id': flow.flow_id,
|
'flow_id': flow.flow_id,
|
||||||
|
@ -51,7 +53,8 @@ class FlowManager:
|
||||||
'source': flow.source,
|
'source': flow.source,
|
||||||
} for flow in self._progress.values()]
|
} for flow in self._progress.values()]
|
||||||
|
|
||||||
async def async_init(self, handler, *, source=SOURCE_USER, data=None):
|
async def async_init(self, handler: Callable, *, source: str = SOURCE_USER,
|
||||||
|
data: str = None) -> Any:
|
||||||
"""Start a configuration flow."""
|
"""Start a configuration flow."""
|
||||||
flow = await self._async_create_flow(handler, source=source, data=data)
|
flow = await self._async_create_flow(handler, source=source, data=data)
|
||||||
flow.hass = self.hass
|
flow.hass = self.hass
|
||||||
|
@ -67,7 +70,8 @@ class FlowManager:
|
||||||
|
|
||||||
return await self._async_handle_step(flow, step, data)
|
return await self._async_handle_step(flow, step, data)
|
||||||
|
|
||||||
async def async_configure(self, flow_id, user_input=None):
|
async def async_configure(
|
||||||
|
self, flow_id: str, user_input: str = None) -> Any:
|
||||||
"""Continue a configuration flow."""
|
"""Continue a configuration flow."""
|
||||||
flow = self._progress.get(flow_id)
|
flow = self._progress.get(flow_id)
|
||||||
|
|
||||||
|
@ -83,12 +87,13 @@ class FlowManager:
|
||||||
flow, step_id, user_input)
|
flow, step_id, user_input)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_abort(self, flow_id):
|
def async_abort(self, flow_id: str) -> None:
|
||||||
"""Abort a flow."""
|
"""Abort a flow."""
|
||||||
if self._progress.pop(flow_id, None) is None:
|
if self._progress.pop(flow_id, None) is None:
|
||||||
raise UnknownFlow
|
raise UnknownFlow
|
||||||
|
|
||||||
async def _async_handle_step(self, flow, step_id, user_input):
|
async def _async_handle_step(self, flow: Any, step_id: str,
|
||||||
|
user_input: Optional[str]) -> Dict:
|
||||||
"""Handle a step of a flow."""
|
"""Handle a step of a flow."""
|
||||||
method = "async_step_{}".format(step_id)
|
method = "async_step_{}".format(step_id)
|
||||||
|
|
||||||
|
@ -97,7 +102,7 @@ class FlowManager:
|
||||||
raise UnknownStep("Handler {} doesn't support step {}".format(
|
raise UnknownStep("Handler {} doesn't support step {}".format(
|
||||||
flow.__class__.__name__, step_id))
|
flow.__class__.__name__, step_id))
|
||||||
|
|
||||||
result = await getattr(flow, method)(user_input)
|
result = await getattr(flow, method)(user_input) # type: Dict
|
||||||
|
|
||||||
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
|
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
|
||||||
RESULT_TYPE_ABORT):
|
RESULT_TYPE_ABORT):
|
||||||
|
@ -133,8 +138,9 @@ class FlowHandler:
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_show_form(self, *, step_id, data_schema=None, errors=None,
|
def async_show_form(self, *, step_id: str, data_schema: vol.Schema = None,
|
||||||
description_placeholders=None):
|
errors: Dict = None,
|
||||||
|
description_placeholders: Dict = None) -> Dict:
|
||||||
"""Return the definition of a form to gather user input."""
|
"""Return the definition of a form to gather user input."""
|
||||||
return {
|
return {
|
||||||
'type': RESULT_TYPE_FORM,
|
'type': RESULT_TYPE_FORM,
|
||||||
|
@ -147,7 +153,7 @@ class FlowHandler:
|
||||||
}
|
}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_entry(self, *, title, data):
|
def async_create_entry(self, *, title: str, data: Dict) -> Dict:
|
||||||
"""Finish config flow and create a config entry."""
|
"""Finish config flow and create a config entry."""
|
||||||
return {
|
return {
|
||||||
'version': self.VERSION,
|
'version': self.VERSION,
|
||||||
|
@ -160,7 +166,7 @@ class FlowHandler:
|
||||||
}
|
}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_abort(self, *, reason):
|
def async_abort(self, *, reason: str) -> Dict:
|
||||||
"""Abort the config flow."""
|
"""Abort the config flow."""
|
||||||
return {
|
return {
|
||||||
'type': RESULT_TYPE_ABORT,
|
'type': RESULT_TYPE_ABORT,
|
||||||
|
|
|
@ -17,7 +17,7 @@ import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from typing import Optional, Set, TYPE_CHECKING # NOQA
|
from typing import Optional, Set, TYPE_CHECKING, Callable, Any, TypeVar # NOQA
|
||||||
|
|
||||||
from homeassistant.const import PLATFORM_FORMAT
|
from homeassistant.const import PLATFORM_FORMAT
|
||||||
from homeassistant.util import OrderedSet
|
from homeassistant.util import OrderedSet
|
||||||
|
@ -27,6 +27,8 @@ from homeassistant.util import OrderedSet
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from homeassistant.core import HomeAssistant # NOQA
|
from homeassistant.core import HomeAssistant # NOQA
|
||||||
|
|
||||||
|
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
|
||||||
|
|
||||||
PREPARED = False
|
PREPARED = False
|
||||||
|
|
||||||
DEPENDENCY_BLACKLIST = {'config'}
|
DEPENDENCY_BLACKLIST = {'config'}
|
||||||
|
@ -51,7 +53,8 @@ def set_component(hass, # type: HomeAssistant
|
||||||
cache[comp_name] = component
|
cache[comp_name] = component
|
||||||
|
|
||||||
|
|
||||||
def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]:
|
def get_platform(hass, # type: HomeAssistant
|
||||||
|
domain: str, platform: str) -> Optional[ModuleType]:
|
||||||
"""Try to load specified platform.
|
"""Try to load specified platform.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -59,7 +62,8 @@ def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]:
|
||||||
return get_component(hass, PLATFORM_FORMAT.format(domain, platform))
|
return get_component(hass, PLATFORM_FORMAT.format(domain, platform))
|
||||||
|
|
||||||
|
|
||||||
def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
|
def get_component(hass, # type: HomeAssistant
|
||||||
|
comp_or_platform: str) -> Optional[ModuleType]:
|
||||||
"""Try to load specified component.
|
"""Try to load specified component.
|
||||||
|
|
||||||
Looks in config dir first, then built-in components.
|
Looks in config dir first, then built-in components.
|
||||||
|
@ -73,6 +77,9 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
|
||||||
|
|
||||||
cache = hass.data.get(DATA_KEY)
|
cache = hass.data.get(DATA_KEY)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
|
if hass.config.config_dir is None:
|
||||||
|
_LOGGER.error("Can't load components - config dir is not set")
|
||||||
|
return None
|
||||||
# Only insert if it's not there (happens during tests)
|
# Only insert if it's not there (happens during tests)
|
||||||
if sys.path[0] != hass.config.config_dir:
|
if sys.path[0] != hass.config.config_dir:
|
||||||
sys.path.insert(0, hass.config.config_dir)
|
sys.path.insert(0, hass.config.config_dir)
|
||||||
|
@ -134,14 +141,38 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleWrapper:
|
||||||
|
"""Class to wrap a Python module and auto fill in hass argument."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hass, # type: HomeAssistant
|
||||||
|
module: ModuleType) -> None:
|
||||||
|
"""Initialize the module wrapper."""
|
||||||
|
self._hass = hass
|
||||||
|
self._module = module
|
||||||
|
|
||||||
|
def __getattr__(self, attr: str) -> Any:
|
||||||
|
"""Fetch an attribute."""
|
||||||
|
value = getattr(self._module, attr)
|
||||||
|
|
||||||
|
if hasattr(value, '__bind_hass'):
|
||||||
|
value = ft.partial(value, self._hass)
|
||||||
|
|
||||||
|
setattr(self, attr, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Components:
|
class Components:
|
||||||
"""Helper to load components."""
|
"""Helper to load components."""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass # type: HomeAssistant
|
||||||
|
) -> None:
|
||||||
"""Initialize the Components class."""
|
"""Initialize the Components class."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
|
||||||
def __getattr__(self, comp_name):
|
def __getattr__(self, comp_name: str) -> ModuleWrapper:
|
||||||
"""Fetch a component."""
|
"""Fetch a component."""
|
||||||
component = get_component(self._hass, comp_name)
|
component = get_component(self._hass, comp_name)
|
||||||
if component is None:
|
if component is None:
|
||||||
|
@ -154,11 +185,14 @@ class Components:
|
||||||
class Helpers:
|
class Helpers:
|
||||||
"""Helper to load helpers."""
|
"""Helper to load helpers."""
|
||||||
|
|
||||||
def __init__(self, hass):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass # type: HomeAssistant
|
||||||
|
) -> None:
|
||||||
"""Initialize the Helpers class."""
|
"""Initialize the Helpers class."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
|
||||||
def __getattr__(self, helper_name):
|
def __getattr__(self, helper_name: str) -> ModuleWrapper:
|
||||||
"""Fetch a helper."""
|
"""Fetch a helper."""
|
||||||
helper = importlib.import_module(
|
helper = importlib.import_module(
|
||||||
'homeassistant.helpers.{}'.format(helper_name))
|
'homeassistant.helpers.{}'.format(helper_name))
|
||||||
|
@ -167,33 +201,14 @@ class Helpers:
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class ModuleWrapper:
|
def bind_hass(func: CALLABLE_T) -> CALLABLE_T:
|
||||||
"""Class to wrap a Python module and auto fill in hass argument."""
|
|
||||||
|
|
||||||
def __init__(self, hass, module):
|
|
||||||
"""Initialize the module wrapper."""
|
|
||||||
self._hass = hass
|
|
||||||
self._module = module
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
"""Fetch an attribute."""
|
|
||||||
value = getattr(self._module, attr)
|
|
||||||
|
|
||||||
if hasattr(value, '__bind_hass'):
|
|
||||||
value = ft.partial(value, self._hass)
|
|
||||||
|
|
||||||
setattr(self, attr, value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def bind_hass(func):
|
|
||||||
"""Decorate function to indicate that first argument is hass."""
|
"""Decorate function to indicate that first argument is hass."""
|
||||||
# pylint: disable=protected-access
|
setattr(func, '__bind_hass', True)
|
||||||
func.__bind_hass = True
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
def load_order_component(hass, comp_name: str) -> OrderedSet:
|
def load_order_component(hass, # type: HomeAssistant
|
||||||
|
comp_name: str) -> OrderedSet:
|
||||||
"""Return an OrderedSet of components in the correct order of loading.
|
"""Return an OrderedSet of components in the correct order of loading.
|
||||||
|
|
||||||
Raises HomeAssistantError if a circular dependency is detected.
|
Raises HomeAssistantError if a circular dependency is detected.
|
||||||
|
@ -204,7 +219,8 @@ def load_order_component(hass, comp_name: str) -> OrderedSet:
|
||||||
return _load_order_component(hass, comp_name, OrderedSet(), set())
|
return _load_order_component(hass, comp_name, OrderedSet(), set())
|
||||||
|
|
||||||
|
|
||||||
def _load_order_component(hass, comp_name: str, load_order: OrderedSet,
|
def _load_order_component(hass, # type: HomeAssistant
|
||||||
|
comp_name: str, load_order: OrderedSet,
|
||||||
loading: Set) -> OrderedSet:
|
loading: Set) -> OrderedSet:
|
||||||
"""Recursive function to get load order of components.
|
"""Recursive function to get load order of components.
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,10 @@ Related Python bugs:
|
||||||
- https://bugs.python.org/issue26617
|
- https://bugs.python.org/issue26617
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def patch_weakref_tasks():
|
def patch_weakref_tasks() -> None:
|
||||||
"""Replace weakref.WeakSet to address Python 3 bug."""
|
"""Replace weakref.WeakSet to address Python 3 bug."""
|
||||||
# pylint: disable=no-self-use, protected-access, bare-except
|
# pylint: disable=no-self-use, protected-access, bare-except
|
||||||
import asyncio.tasks
|
import asyncio.tasks
|
||||||
|
@ -30,7 +31,7 @@ def patch_weakref_tasks():
|
||||||
class IgnoreCalls:
|
class IgnoreCalls:
|
||||||
"""Ignore add calls."""
|
"""Ignore add calls."""
|
||||||
|
|
||||||
def add(self, other):
|
def add(self, other: Any) -> None:
|
||||||
"""No-op add."""
|
"""No-op add."""
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ def patch_weakref_tasks():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def disable_c_asyncio():
|
def disable_c_asyncio() -> None:
|
||||||
"""Disable using C implementation of asyncio.
|
"""Disable using C implementation of asyncio.
|
||||||
|
|
||||||
Required to be able to apply the weakref monkey patch.
|
Required to be able to apply the weakref monkey patch.
|
||||||
|
@ -53,12 +54,12 @@ def disable_c_asyncio():
|
||||||
|
|
||||||
PATH_TRIGGER = '_asyncio'
|
PATH_TRIGGER = '_asyncio'
|
||||||
|
|
||||||
def __init__(self, path_entry):
|
def __init__(self, path_entry: str) -> None:
|
||||||
if path_entry != self.PATH_TRIGGER:
|
if path_entry != self.PATH_TRIGGER:
|
||||||
raise ImportError()
|
raise ImportError()
|
||||||
return
|
return
|
||||||
|
|
||||||
def find_module(self, fullname, path=None):
|
def find_module(self, fullname: str, path: Any = None) -> None:
|
||||||
"""Find a module."""
|
"""Find a module."""
|
||||||
if fullname == self.PATH_TRIGGER:
|
if fullname == self.PATH_TRIGGER:
|
||||||
# We lint in Py35, exception is introduced in Py36
|
# We lint in Py35, exception is introduced in Py36
|
||||||
|
|
|
@ -13,7 +13,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Dict, Any, List
|
||||||
|
|
||||||
from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE
|
from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE
|
||||||
import requests
|
import requests
|
||||||
|
@ -62,7 +62,7 @@ class API:
|
||||||
if port is not None:
|
if port is not None:
|
||||||
self.base_url += ':{}'.format(port)
|
self.base_url += ':{}'.format(port)
|
||||||
|
|
||||||
self.status = None
|
self.status = None # type: Optional[APIStatus]
|
||||||
self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON}
|
self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON}
|
||||||
|
|
||||||
if api_password is not None:
|
if api_password is not None:
|
||||||
|
@ -75,20 +75,24 @@ class API:
|
||||||
|
|
||||||
return self.status == APIStatus.OK
|
return self.status == APIStatus.OK
|
||||||
|
|
||||||
def __call__(self, method, path, data=None, timeout=5):
|
def __call__(self, method: str, path: str, data: Dict = None,
|
||||||
|
timeout: int = 5) -> requests.Response:
|
||||||
"""Make a call to the Home Assistant API."""
|
"""Make a call to the Home Assistant API."""
|
||||||
if data is not None:
|
if data is None:
|
||||||
data = json.dumps(data, cls=JSONEncoder)
|
data_str = None
|
||||||
|
else:
|
||||||
|
data_str = json.dumps(data, cls=JSONEncoder)
|
||||||
|
|
||||||
url = urllib.parse.urljoin(self.base_url, path)
|
url = urllib.parse.urljoin(self.base_url, path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if method == METH_GET:
|
if method == METH_GET:
|
||||||
return requests.get(
|
return requests.get(
|
||||||
url, params=data, timeout=timeout, headers=self._headers)
|
url, params=data_str, timeout=timeout,
|
||||||
|
headers=self._headers)
|
||||||
|
|
||||||
return requests.request(
|
return requests.request(
|
||||||
method, url, data=data, timeout=timeout,
|
method, url, data=data_str, timeout=timeout,
|
||||||
headers=self._headers)
|
headers=self._headers)
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
|
@ -110,7 +114,7 @@ class JSONEncoder(json.JSONEncoder):
|
||||||
"""JSONEncoder that supports Home Assistant objects."""
|
"""JSONEncoder that supports Home Assistant objects."""
|
||||||
|
|
||||||
# pylint: disable=method-hidden
|
# pylint: disable=method-hidden
|
||||||
def default(self, o):
|
def default(self, o: Any) -> Any:
|
||||||
"""Convert Home Assistant objects.
|
"""Convert Home Assistant objects.
|
||||||
|
|
||||||
Hand other objects to the original method.
|
Hand other objects to the original method.
|
||||||
|
@ -125,7 +129,7 @@ class JSONEncoder(json.JSONEncoder):
|
||||||
return json.JSONEncoder.default(self, o)
|
return json.JSONEncoder.default(self, o)
|
||||||
|
|
||||||
|
|
||||||
def validate_api(api):
|
def validate_api(api: API) -> APIStatus:
|
||||||
"""Make a call to validate API."""
|
"""Make a call to validate API."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET, URL_API)
|
req = api(METH_GET, URL_API)
|
||||||
|
@ -142,12 +146,12 @@ def validate_api(api):
|
||||||
return APIStatus.CANNOT_CONNECT
|
return APIStatus.CANNOT_CONNECT
|
||||||
|
|
||||||
|
|
||||||
def get_event_listeners(api):
|
def get_event_listeners(api: API) -> Dict:
|
||||||
"""List of events that is being listened for."""
|
"""List of events that is being listened for."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET, URL_API_EVENTS)
|
req = api(METH_GET, URL_API_EVENTS)
|
||||||
|
|
||||||
return req.json() if req.status_code == 200 else {}
|
return req.json() if req.status_code == 200 else {} # type: ignore
|
||||||
|
|
||||||
except (HomeAssistantError, ValueError):
|
except (HomeAssistantError, ValueError):
|
||||||
# ValueError if req.json() can't parse the json
|
# ValueError if req.json() can't parse the json
|
||||||
|
@ -156,7 +160,7 @@ def get_event_listeners(api):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def fire_event(api, event_type, data=None):
|
def fire_event(api: API, event_type: str, data: Dict = None) -> None:
|
||||||
"""Fire an event at remote API."""
|
"""Fire an event at remote API."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data)
|
req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data)
|
||||||
|
@ -169,7 +173,7 @@ def fire_event(api, event_type, data=None):
|
||||||
_LOGGER.exception("Error firing event")
|
_LOGGER.exception("Error firing event")
|
||||||
|
|
||||||
|
|
||||||
def get_state(api, entity_id):
|
def get_state(api: API, entity_id: str) -> Optional[ha.State]:
|
||||||
"""Query given API for state of entity_id."""
|
"""Query given API for state of entity_id."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id))
|
req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id))
|
||||||
|
@ -186,7 +190,7 @@ def get_state(api, entity_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_states(api):
|
def get_states(api: API) -> List[ha.State]:
|
||||||
"""Query given API for all states."""
|
"""Query given API for all states."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET,
|
req = api(METH_GET,
|
||||||
|
@ -202,7 +206,7 @@ def get_states(api):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def remove_state(api, entity_id):
|
def remove_state(api: API, entity_id: str) -> bool:
|
||||||
"""Call API to remove state for entity_id.
|
"""Call API to remove state for entity_id.
|
||||||
|
|
||||||
Return True if entity is gone (removed/never existed).
|
Return True if entity is gone (removed/never existed).
|
||||||
|
@ -222,7 +226,8 @@ def remove_state(api, entity_id):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def set_state(api, entity_id, new_state, attributes=None, force_update=False):
|
def set_state(api: API, entity_id: str, new_state: str,
|
||||||
|
attributes: Dict = None, force_update: bool = False) -> bool:
|
||||||
"""Tell API to update state for entity_id.
|
"""Tell API to update state for entity_id.
|
||||||
|
|
||||||
Return True if success.
|
Return True if success.
|
||||||
|
@ -249,14 +254,14 @@ def set_state(api, entity_id, new_state, attributes=None, force_update=False):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_state(api, entity_id, state):
|
def is_state(api: API, entity_id: str, state: str) -> bool:
|
||||||
"""Query API to see if entity_id is specified state."""
|
"""Query API to see if entity_id is specified state."""
|
||||||
cur_state = get_state(api, entity_id)
|
cur_state = get_state(api, entity_id)
|
||||||
|
|
||||||
return cur_state and cur_state.state == state
|
return bool(cur_state and cur_state.state == state)
|
||||||
|
|
||||||
|
|
||||||
def get_services(api):
|
def get_services(api: API) -> Dict:
|
||||||
"""Return a list of dicts.
|
"""Return a list of dicts.
|
||||||
|
|
||||||
Each dict has a string "domain" and a list of strings "services".
|
Each dict has a string "domain" and a list of strings "services".
|
||||||
|
@ -264,7 +269,7 @@ def get_services(api):
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET, URL_API_SERVICES)
|
req = api(METH_GET, URL_API_SERVICES)
|
||||||
|
|
||||||
return req.json() if req.status_code == 200 else {}
|
return req.json() if req.status_code == 200 else {} # type: ignore
|
||||||
|
|
||||||
except (HomeAssistantError, ValueError):
|
except (HomeAssistantError, ValueError):
|
||||||
# ValueError if req.json() can't parse the json
|
# ValueError if req.json() can't parse the json
|
||||||
|
@ -273,7 +278,9 @@ def get_services(api):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def call_service(api, domain, service, service_data=None, timeout=5):
|
def call_service(api: API, domain: str, service: str,
|
||||||
|
service_data: Dict = None,
|
||||||
|
timeout: int = 5) -> None:
|
||||||
"""Call a service at the remote API."""
|
"""Call a service at the remote API."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_POST,
|
req = api(METH_POST,
|
||||||
|
@ -288,7 +295,7 @@ def call_service(api, domain, service, service_data=None, timeout=5):
|
||||||
_LOGGER.exception("Error calling service")
|
_LOGGER.exception("Error calling service")
|
||||||
|
|
||||||
|
|
||||||
def get_config(api):
|
def get_config(api: API) -> Dict:
|
||||||
"""Return configuration."""
|
"""Return configuration."""
|
||||||
try:
|
try:
|
||||||
req = api(METH_GET, URL_API_CONFIG)
|
req = api(METH_GET, URL_API_CONFIG)
|
||||||
|
@ -299,7 +306,7 @@ def get_config(api):
|
||||||
result = req.json()
|
result = req.json()
|
||||||
if 'components' in result:
|
if 'components' in result:
|
||||||
result['components'] = set(result['components'])
|
result['components'] = set(result['components'])
|
||||||
return result
|
return result # type: ignore
|
||||||
|
|
||||||
except (HomeAssistantError, ValueError):
|
except (HomeAssistantError, ValueError):
|
||||||
# ValueError if req.json() can't parse the JSON
|
# ValueError if req.json() can't parse the JSON
|
||||||
|
|
|
@ -3,15 +3,18 @@ import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
import homeassistant.util.package as pkg_util
|
import homeassistant.util.package as pkg_util
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
DATA_PIP_LOCK = 'pip_lock'
|
DATA_PIP_LOCK = 'pip_lock'
|
||||||
CONSTRAINT_FILE = 'package_constraints.txt'
|
CONSTRAINT_FILE = 'package_constraints.txt'
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_process_requirements(hass, name, requirements):
|
async def async_process_requirements(hass: HomeAssistant, name: str,
|
||||||
|
requirements: List[str]) -> bool:
|
||||||
"""Install the requirements for a component or platform.
|
"""Install the requirements for a component or platform.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -25,7 +28,7 @@ async def async_process_requirements(hass, name, requirements):
|
||||||
|
|
||||||
async with pip_lock:
|
async with pip_lock:
|
||||||
for req in requirements:
|
for req in requirements:
|
||||||
ret = await hass.async_add_job(pip_install, req)
|
ret = await hass.async_add_executor_job(pip_install, req)
|
||||||
if not ret:
|
if not ret:
|
||||||
_LOGGER.error("Not initializing %s because could not install "
|
_LOGGER.error("Not initializing %s because could not install "
|
||||||
"requirement %s", name, req)
|
"requirement %s", name, req)
|
||||||
|
@ -34,11 +37,11 @@ async def async_process_requirements(hass, name, requirements):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def pip_kwargs(config_dir):
|
def pip_kwargs(config_dir: Optional[str]) -> Dict[str, str]:
|
||||||
"""Return keyword arguments for PIP install."""
|
"""Return keyword arguments for PIP install."""
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE)
|
'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE)
|
||||||
}
|
}
|
||||||
if not pkg_util.is_virtual_env():
|
if not (config_dir is None or pkg_util.is_virtual_env()):
|
||||||
kwargs['target'] = os.path.join(config_dir, 'deps')
|
kwargs['target'] = os.path.join(config_dir, 'deps')
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
|
@ -4,7 +4,7 @@ import logging.handlers
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict, List
|
||||||
|
|
||||||
from homeassistant import requirements, core, loader, config as conf_util
|
from homeassistant import requirements, core, loader, config as conf_util
|
||||||
from homeassistant.config import async_notify_setup_error
|
from homeassistant.config import async_notify_setup_error
|
||||||
|
@ -56,7 +56,9 @@ async def async_setup_component(hass: core.HomeAssistant, domain: str,
|
||||||
return await task # type: ignore
|
return await task # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def _async_process_dependencies(hass, config, name, dependencies):
|
async def _async_process_dependencies(
|
||||||
|
hass: core.HomeAssistant, config: Dict, name: str,
|
||||||
|
dependencies: List[str]) -> bool:
|
||||||
"""Ensure all dependencies are set up."""
|
"""Ensure all dependencies are set up."""
|
||||||
blacklisted = [dep for dep in dependencies
|
blacklisted = [dep for dep in dependencies
|
||||||
if dep in loader.DEPENDENCY_BLACKLIST]
|
if dep in loader.DEPENDENCY_BLACKLIST]
|
||||||
|
@ -88,12 +90,12 @@ async def _async_process_dependencies(hass, config, name, dependencies):
|
||||||
|
|
||||||
|
|
||||||
async def _async_setup_component(hass: core.HomeAssistant,
|
async def _async_setup_component(hass: core.HomeAssistant,
|
||||||
domain: str, config) -> bool:
|
domain: str, config: Dict) -> bool:
|
||||||
"""Set up a component for Home Assistant.
|
"""Set up a component for Home Assistant.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
def log_error(msg, link=True):
|
def log_error(msg: str, link: bool = True) -> None:
|
||||||
"""Log helper."""
|
"""Log helper."""
|
||||||
_LOGGER.error("Setup failed for %s: %s", domain, msg)
|
_LOGGER.error("Setup failed for %s: %s", domain, msg)
|
||||||
async_notify_setup_error(hass, domain, link)
|
async_notify_setup_error(hass, domain, link)
|
||||||
|
@ -181,7 +183,7 @@ async def _async_setup_component(hass: core.HomeAssistant,
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
|
async def async_prepare_setup_platform(hass: core.HomeAssistant, config: Dict,
|
||||||
domain: str, platform_name: str) \
|
domain: str, platform_name: str) \
|
||||||
-> Optional[ModuleType]:
|
-> Optional[ModuleType]:
|
||||||
"""Load a platform and makes sure dependencies are setup.
|
"""Load a platform and makes sure dependencies are setup.
|
||||||
|
@ -190,7 +192,7 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
|
||||||
"""
|
"""
|
||||||
platform_path = PLATFORM_FORMAT.format(domain, platform_name)
|
platform_path = PLATFORM_FORMAT.format(domain, platform_name)
|
||||||
|
|
||||||
def log_error(msg):
|
def log_error(msg: str) -> None:
|
||||||
"""Log helper."""
|
"""Log helper."""
|
||||||
_LOGGER.error("Unable to prepare setup for platform %s: %s",
|
_LOGGER.error("Unable to prepare setup for platform %s: %s",
|
||||||
platform_path, msg)
|
platform_path, msg)
|
||||||
|
@ -217,7 +219,9 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
|
||||||
return platform
|
return platform
|
||||||
|
|
||||||
|
|
||||||
async def async_process_deps_reqs(hass, config, name, module):
|
async def async_process_deps_reqs(
|
||||||
|
hass: core.HomeAssistant, config: Dict, name: str,
|
||||||
|
module: ModuleType) -> None:
|
||||||
"""Process all dependencies and requirements for a module.
|
"""Process all dependencies and requirements for a module.
|
||||||
|
|
||||||
Module is a Python module of either a component or platform.
|
Module is a Python module of either a component or platform.
|
||||||
|
@ -231,14 +235,14 @@ async def async_process_deps_reqs(hass, config, name, module):
|
||||||
|
|
||||||
if hasattr(module, 'DEPENDENCIES'):
|
if hasattr(module, 'DEPENDENCIES'):
|
||||||
dep_success = await _async_process_dependencies(
|
dep_success = await _async_process_dependencies(
|
||||||
hass, config, name, module.DEPENDENCIES)
|
hass, config, name, module.DEPENDENCIES) # type: ignore
|
||||||
|
|
||||||
if not dep_success:
|
if not dep_success:
|
||||||
raise HomeAssistantError("Could not setup all dependencies.")
|
raise HomeAssistantError("Could not setup all dependencies.")
|
||||||
|
|
||||||
if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'):
|
if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'):
|
||||||
req_success = await requirements.async_process_requirements(
|
req_success = await requirements.async_process_requirements(
|
||||||
hass, name, module.REQUIREMENTS)
|
hass, name, module.REQUIREMENTS) # type: ignore
|
||||||
|
|
||||||
if not req_success:
|
if not req_success:
|
||||||
raise HomeAssistantError("Could not install all requirements.")
|
raise HomeAssistantError("Could not install all requirements.")
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
"""Helper methods for various modules."""
|
"""Helper methods for various modules."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import MutableSet
|
from datetime import datetime, timedelta
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
|
||||||
import re
|
import re
|
||||||
import enum
|
import enum
|
||||||
import socket
|
import socket
|
||||||
|
@ -14,12 +13,13 @@ from types import MappingProxyType
|
||||||
from unicodedata import normalize
|
from unicodedata import normalize
|
||||||
|
|
||||||
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
|
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
|
||||||
Iterable, List, Mapping)
|
Iterable, List, Dict, Iterator, Coroutine, MutableSet)
|
||||||
|
|
||||||
from .dt import as_local, utcnow
|
from .dt import as_local, utcnow
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
U = TypeVar('U')
|
U = TypeVar('U')
|
||||||
|
ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
|
||||||
|
|
||||||
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
|
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
|
||||||
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
|
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
|
||||||
|
@ -91,7 +91,7 @@ def ensure_unique_string(preferred_string: str, current_strings:
|
||||||
|
|
||||||
|
|
||||||
# Taken from: http://stackoverflow.com/a/11735897
|
# Taken from: http://stackoverflow.com/a/11735897
|
||||||
def get_local_ip():
|
def get_local_ip() -> str:
|
||||||
"""Try to determine the local IP address of the machine."""
|
"""Try to determine the local IP address of the machine."""
|
||||||
try:
|
try:
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
@ -99,7 +99,7 @@ def get_local_ip():
|
||||||
# Use Google Public DNS server to determine own IP
|
# Use Google Public DNS server to determine own IP
|
||||||
sock.connect(('8.8.8.8', 80))
|
sock.connect(('8.8.8.8', 80))
|
||||||
|
|
||||||
return sock.getsockname()[0]
|
return sock.getsockname()[0] # type: ignore
|
||||||
except socket.error:
|
except socket.error:
|
||||||
try:
|
try:
|
||||||
return socket.gethostbyname(socket.gethostname())
|
return socket.gethostbyname(socket.gethostname())
|
||||||
|
@ -110,7 +110,7 @@ def get_local_ip():
|
||||||
|
|
||||||
|
|
||||||
# Taken from http://stackoverflow.com/a/23728630
|
# Taken from http://stackoverflow.com/a/23728630
|
||||||
def get_random_string(length=10):
|
def get_random_string(length: int = 10) -> str:
|
||||||
"""Return a random string with letters and digits."""
|
"""Return a random string with letters and digits."""
|
||||||
generator = random.SystemRandom()
|
generator = random.SystemRandom()
|
||||||
source_chars = string.ascii_letters + string.digits
|
source_chars = string.ascii_letters + string.digits
|
||||||
|
@ -121,59 +121,59 @@ def get_random_string(length=10):
|
||||||
class OrderedEnum(enum.Enum):
|
class OrderedEnum(enum.Enum):
|
||||||
"""Taken from Python 3.4.0 docs."""
|
"""Taken from Python 3.4.0 docs."""
|
||||||
|
|
||||||
def __ge__(self, other):
|
def __ge__(self: ENUM_T, other: ENUM_T) -> bool:
|
||||||
"""Return the greater than element."""
|
"""Return the greater than element."""
|
||||||
if self.__class__ is other.__class__:
|
if self.__class__ is other.__class__:
|
||||||
return self.value >= other.value
|
return bool(self.value >= other.value)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __gt__(self, other):
|
def __gt__(self: ENUM_T, other: ENUM_T) -> bool:
|
||||||
"""Return the greater element."""
|
"""Return the greater element."""
|
||||||
if self.__class__ is other.__class__:
|
if self.__class__ is other.__class__:
|
||||||
return self.value > other.value
|
return bool(self.value > other.value)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __le__(self, other):
|
def __le__(self: ENUM_T, other: ENUM_T) -> bool:
|
||||||
"""Return the lower than element."""
|
"""Return the lower than element."""
|
||||||
if self.__class__ is other.__class__:
|
if self.__class__ is other.__class__:
|
||||||
return self.value <= other.value
|
return bool(self.value <= other.value)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self: ENUM_T, other: ENUM_T) -> bool:
|
||||||
"""Return the lower element."""
|
"""Return the lower element."""
|
||||||
if self.__class__ is other.__class__:
|
if self.__class__ is other.__class__:
|
||||||
return self.value < other.value
|
return bool(self.value < other.value)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
class OrderedSet(MutableSet):
|
class OrderedSet(MutableSet[T]):
|
||||||
"""Ordered set taken from http://code.activestate.com/recipes/576694/."""
|
"""Ordered set taken from http://code.activestate.com/recipes/576694/."""
|
||||||
|
|
||||||
def __init__(self, iterable=None):
|
def __init__(self, iterable: Iterable[T] = None) -> None:
|
||||||
"""Initialize the set."""
|
"""Initialize the set."""
|
||||||
self.end = end = [] # type: List[Any]
|
self.end = end = [] # type: List[Any]
|
||||||
end += [None, end, end] # sentinel node for doubly linked list
|
end += [None, end, end] # sentinel node for doubly linked list
|
||||||
self.map = {} # type: Mapping[List, Any] # key --> [key, prev, next]
|
self.map = {} # type: Dict[T, List] # key --> [key, prev, next]
|
||||||
if iterable is not None:
|
if iterable is not None:
|
||||||
self |= iterable
|
self |= iterable # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
"""Return the length of the set."""
|
"""Return the length of the set."""
|
||||||
return len(self.map)
|
return len(self.map)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key: T) -> bool: # type: ignore
|
||||||
"""Check if key is in set."""
|
"""Check if key is in set."""
|
||||||
return key in self.map
|
return key in self.map
|
||||||
|
|
||||||
# pylint: disable=arguments-differ
|
# pylint: disable=arguments-differ
|
||||||
def add(self, key):
|
def add(self, key: T) -> None:
|
||||||
"""Add an element to the end of the set."""
|
"""Add an element to the end of the set."""
|
||||||
if key not in self.map:
|
if key not in self.map:
|
||||||
end = self.end
|
end = self.end
|
||||||
curr = end[1]
|
curr = end[1]
|
||||||
curr[2] = end[1] = self.map[key] = [key, curr, end]
|
curr[2] = end[1] = self.map[key] = [key, curr, end]
|
||||||
|
|
||||||
def promote(self, key):
|
def promote(self, key: T) -> None:
|
||||||
"""Promote element to beginning of the set, add if not there."""
|
"""Promote element to beginning of the set, add if not there."""
|
||||||
if key in self.map:
|
if key in self.map:
|
||||||
self.discard(key)
|
self.discard(key)
|
||||||
|
@ -183,14 +183,14 @@ class OrderedSet(MutableSet):
|
||||||
curr[2] = begin[1] = self.map[key] = [key, curr, begin]
|
curr[2] = begin[1] = self.map[key] = [key, curr, begin]
|
||||||
|
|
||||||
# pylint: disable=arguments-differ
|
# pylint: disable=arguments-differ
|
||||||
def discard(self, key):
|
def discard(self, key: T) -> None:
|
||||||
"""Discard an element from the set."""
|
"""Discard an element from the set."""
|
||||||
if key in self.map:
|
if key in self.map:
|
||||||
key, prev_item, next_item = self.map.pop(key)
|
key, prev_item, next_item = self.map.pop(key)
|
||||||
prev_item[2] = next_item
|
prev_item[2] = next_item
|
||||||
next_item[1] = prev_item
|
next_item[1] = prev_item
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[T]:
|
||||||
"""Iterate of the set."""
|
"""Iterate of the set."""
|
||||||
end = self.end
|
end = self.end
|
||||||
curr = end[2]
|
curr = end[2]
|
||||||
|
@ -198,7 +198,7 @@ class OrderedSet(MutableSet):
|
||||||
yield curr[0]
|
yield curr[0]
|
||||||
curr = curr[2]
|
curr = curr[2]
|
||||||
|
|
||||||
def __reversed__(self):
|
def __reversed__(self) -> Iterator[T]:
|
||||||
"""Reverse the ordering."""
|
"""Reverse the ordering."""
|
||||||
end = self.end
|
end = self.end
|
||||||
curr = end[1]
|
curr = end[1]
|
||||||
|
@ -207,7 +207,7 @@ class OrderedSet(MutableSet):
|
||||||
curr = curr[1]
|
curr = curr[1]
|
||||||
|
|
||||||
# pylint: disable=arguments-differ
|
# pylint: disable=arguments-differ
|
||||||
def pop(self, last=True):
|
def pop(self, last: bool = True) -> T:
|
||||||
"""Pop element of the end of the set.
|
"""Pop element of the end of the set.
|
||||||
|
|
||||||
Set last=False to pop from the beginning.
|
Set last=False to pop from the beginning.
|
||||||
|
@ -216,20 +216,20 @@ class OrderedSet(MutableSet):
|
||||||
raise KeyError('set is empty')
|
raise KeyError('set is empty')
|
||||||
key = self.end[1][0] if last else self.end[2][0]
|
key = self.end[1][0] if last else self.end[2][0]
|
||||||
self.discard(key)
|
self.discard(key)
|
||||||
return key
|
return key # type: ignore
|
||||||
|
|
||||||
def update(self, *args):
|
def update(self, *args: Any) -> None:
|
||||||
"""Add elements from args to the set."""
|
"""Add elements from args to the set."""
|
||||||
for item in chain(*args):
|
for item in chain(*args):
|
||||||
self.add(item)
|
self.add(item)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the representation."""
|
"""Return the representation."""
|
||||||
if not self:
|
if not self:
|
||||||
return '%s()' % (self.__class__.__name__,)
|
return '%s()' % (self.__class__.__name__,)
|
||||||
return '%s(%r)' % (self.__class__.__name__, list(self))
|
return '%s(%r)' % (self.__class__.__name__, list(self))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison."""
|
"""Return the comparison."""
|
||||||
if isinstance(other, OrderedSet):
|
if isinstance(other, OrderedSet):
|
||||||
return len(self) == len(other) and list(self) == list(other)
|
return len(self) == len(other) and list(self) == list(other)
|
||||||
|
@ -254,20 +254,21 @@ class Throttle:
|
||||||
Adds a datetime attribute `last_call` to the method.
|
Adds a datetime attribute `last_call` to the method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min_time, limit_no_throttle=None):
|
def __init__(self, min_time: timedelta,
|
||||||
|
limit_no_throttle: timedelta = None) -> None:
|
||||||
"""Initialize the throttle."""
|
"""Initialize the throttle."""
|
||||||
self.min_time = min_time
|
self.min_time = min_time
|
||||||
self.limit_no_throttle = limit_no_throttle
|
self.limit_no_throttle = limit_no_throttle
|
||||||
|
|
||||||
def __call__(self, method):
|
def __call__(self, method: Callable) -> Callable:
|
||||||
"""Caller for the throttle."""
|
"""Caller for the throttle."""
|
||||||
# Make sure we return a coroutine if the method is async.
|
# Make sure we return a coroutine if the method is async.
|
||||||
if asyncio.iscoroutinefunction(method):
|
if asyncio.iscoroutinefunction(method):
|
||||||
async def throttled_value():
|
async def throttled_value() -> None:
|
||||||
"""Stand-in function for when real func is being throttled."""
|
"""Stand-in function for when real func is being throttled."""
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
def throttled_value():
|
def throttled_value() -> None: # type: ignore
|
||||||
"""Stand-in function for when real func is being throttled."""
|
"""Stand-in function for when real func is being throttled."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -288,14 +289,14 @@ class Throttle:
|
||||||
'.' not in method.__qualname__.split('.<locals>.')[-1])
|
'.' not in method.__qualname__.split('.<locals>.')[-1])
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
|
||||||
"""Wrap that allows wrapped to be called only once per min_time.
|
"""Wrap that allows wrapped to be called only once per min_time.
|
||||||
|
|
||||||
If we cannot acquire the lock, it is running so return None.
|
If we cannot acquire the lock, it is running so return None.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if hasattr(method, '__self__'):
|
if hasattr(method, '__self__'):
|
||||||
host = method.__self__
|
host = getattr(method, '__self__')
|
||||||
elif is_func:
|
elif is_func:
|
||||||
host = wrapper
|
host = wrapper
|
||||||
else:
|
else:
|
||||||
|
@ -318,7 +319,7 @@ class Throttle:
|
||||||
if force or utcnow() - throttle[1] > self.min_time:
|
if force or utcnow() - throttle[1] > self.min_time:
|
||||||
result = method(*args, **kwargs)
|
result = method(*args, **kwargs)
|
||||||
throttle[1] = utcnow()
|
throttle[1] = utcnow()
|
||||||
return result
|
return result # type: ignore
|
||||||
|
|
||||||
return throttled_value()
|
return throttled_value()
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -3,22 +3,25 @@ import concurrent.futures
|
||||||
import threading
|
import threading
|
||||||
import logging
|
import logging
|
||||||
from asyncio import coroutines
|
from asyncio import coroutines
|
||||||
|
from asyncio.events import AbstractEventLoop
|
||||||
from asyncio.futures import Future
|
from asyncio.futures import Future
|
||||||
|
|
||||||
from asyncio import ensure_future
|
from asyncio import ensure_future
|
||||||
|
from typing import Any, Union, Coroutine, Callable, Generator
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _set_result_unless_cancelled(fut, result):
|
def _set_result_unless_cancelled(fut: Future, result: Any) -> None:
|
||||||
"""Set the result only if the Future was not cancelled."""
|
"""Set the result only if the Future was not cancelled."""
|
||||||
if fut.cancelled():
|
if fut.cancelled():
|
||||||
return
|
return
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
|
|
||||||
def _set_concurrent_future_state(concurr, source):
|
def _set_concurrent_future_state(
|
||||||
|
concurr: concurrent.futures.Future,
|
||||||
|
source: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
"""Copy state from a future to a concurrent.futures.Future."""
|
"""Copy state from a future to a concurrent.futures.Future."""
|
||||||
assert source.done()
|
assert source.done()
|
||||||
if source.cancelled():
|
if source.cancelled():
|
||||||
|
@ -33,7 +36,8 @@ def _set_concurrent_future_state(concurr, source):
|
||||||
concurr.set_result(result)
|
concurr.set_result(result)
|
||||||
|
|
||||||
|
|
||||||
def _copy_future_state(source, dest):
|
def _copy_future_state(source: Union[concurrent.futures.Future, Future],
|
||||||
|
dest: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
"""Copy state from another Future.
|
"""Copy state from another Future.
|
||||||
|
|
||||||
The other Future may be a concurrent.futures.Future.
|
The other Future may be a concurrent.futures.Future.
|
||||||
|
@ -53,7 +57,9 @@ def _copy_future_state(source, dest):
|
||||||
dest.set_result(result)
|
dest.set_result(result)
|
||||||
|
|
||||||
|
|
||||||
def _chain_future(source, destination):
|
def _chain_future(
|
||||||
|
source: Union[concurrent.futures.Future, Future],
|
||||||
|
destination: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
"""Chain two futures so that when one completes, so does the other.
|
"""Chain two futures so that when one completes, so does the other.
|
||||||
|
|
||||||
The result (or exception) of source will be copied to destination.
|
The result (or exception) of source will be copied to destination.
|
||||||
|
@ -74,20 +80,23 @@ def _chain_future(source, destination):
|
||||||
else:
|
else:
|
||||||
dest_loop = None
|
dest_loop = None
|
||||||
|
|
||||||
def _set_state(future, other):
|
def _set_state(future: Union[concurrent.futures.Future, Future],
|
||||||
|
other: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
if isinstance(future, Future):
|
if isinstance(future, Future):
|
||||||
_copy_future_state(other, future)
|
_copy_future_state(other, future)
|
||||||
else:
|
else:
|
||||||
_set_concurrent_future_state(future, other)
|
_set_concurrent_future_state(future, other)
|
||||||
|
|
||||||
def _call_check_cancel(destination):
|
def _call_check_cancel(
|
||||||
|
destination: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
if destination.cancelled():
|
if destination.cancelled():
|
||||||
if source_loop is None or source_loop is dest_loop:
|
if source_loop is None or source_loop is dest_loop:
|
||||||
source.cancel()
|
source.cancel()
|
||||||
else:
|
else:
|
||||||
source_loop.call_soon_threadsafe(source.cancel)
|
source_loop.call_soon_threadsafe(source.cancel)
|
||||||
|
|
||||||
def _call_set_state(source):
|
def _call_set_state(
|
||||||
|
source: Union[concurrent.futures.Future, Future]) -> None:
|
||||||
if dest_loop is None or dest_loop is source_loop:
|
if dest_loop is None or dest_loop is source_loop:
|
||||||
_set_state(destination, source)
|
_set_state(destination, source)
|
||||||
else:
|
else:
|
||||||
|
@ -97,7 +106,9 @@ def _chain_future(source, destination):
|
||||||
source.add_done_callback(_call_set_state)
|
source.add_done_callback(_call_set_state)
|
||||||
|
|
||||||
|
|
||||||
def run_coroutine_threadsafe(coro, loop):
|
def run_coroutine_threadsafe(
|
||||||
|
coro: Union[Coroutine, Generator],
|
||||||
|
loop: AbstractEventLoop) -> concurrent.futures.Future:
|
||||||
"""Submit a coroutine object to a given event loop.
|
"""Submit a coroutine object to a given event loop.
|
||||||
|
|
||||||
Return a concurrent.futures.Future to access the result.
|
Return a concurrent.futures.Future to access the result.
|
||||||
|
@ -110,7 +121,7 @@ def run_coroutine_threadsafe(coro, loop):
|
||||||
raise TypeError('A coroutine object is required')
|
raise TypeError('A coroutine object is required')
|
||||||
future = concurrent.futures.Future() # type: concurrent.futures.Future
|
future = concurrent.futures.Future() # type: concurrent.futures.Future
|
||||||
|
|
||||||
def callback():
|
def callback() -> None:
|
||||||
"""Handle the call to the coroutine."""
|
"""Handle the call to the coroutine."""
|
||||||
try:
|
try:
|
||||||
_chain_future(ensure_future(coro, loop=loop), future)
|
_chain_future(ensure_future(coro, loop=loop), future)
|
||||||
|
@ -125,7 +136,8 @@ def run_coroutine_threadsafe(coro, loop):
|
||||||
return future
|
return future
|
||||||
|
|
||||||
|
|
||||||
def fire_coroutine_threadsafe(coro, loop):
|
def fire_coroutine_threadsafe(coro: Coroutine,
|
||||||
|
loop: AbstractEventLoop) -> None:
|
||||||
"""Submit a coroutine object to a given event loop.
|
"""Submit a coroutine object to a given event loop.
|
||||||
|
|
||||||
This method does not provide a way to retrieve the result and
|
This method does not provide a way to retrieve the result and
|
||||||
|
@ -139,7 +151,7 @@ def fire_coroutine_threadsafe(coro, loop):
|
||||||
if not coroutines.iscoroutine(coro):
|
if not coroutines.iscoroutine(coro):
|
||||||
raise TypeError('A coroutine object is required: %s' % coro)
|
raise TypeError('A coroutine object is required: %s' % coro)
|
||||||
|
|
||||||
def callback():
|
def callback() -> None:
|
||||||
"""Handle the firing of a coroutine."""
|
"""Handle the firing of a coroutine."""
|
||||||
ensure_future(coro, loop=loop)
|
ensure_future(coro, loop=loop)
|
||||||
|
|
||||||
|
@ -147,7 +159,8 @@ def fire_coroutine_threadsafe(coro, loop):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def run_callback_threadsafe(loop, callback, *args):
|
def run_callback_threadsafe(loop: AbstractEventLoop, callback: Callable,
|
||||||
|
*args: Any) -> concurrent.futures.Future:
|
||||||
"""Submit a callback object to a given event loop.
|
"""Submit a callback object to a given event loop.
|
||||||
|
|
||||||
Return a concurrent.futures.Future to access the result.
|
Return a concurrent.futures.Future to access the result.
|
||||||
|
@ -158,7 +171,7 @@ def run_callback_threadsafe(loop, callback, *args):
|
||||||
|
|
||||||
future = concurrent.futures.Future() # type: concurrent.futures.Future
|
future = concurrent.futures.Future() # type: concurrent.futures.Future
|
||||||
|
|
||||||
def run_callback():
|
def run_callback() -> None:
|
||||||
"""Run callback and store result."""
|
"""Run callback and store result."""
|
||||||
try:
|
try:
|
||||||
future.set_result(callback(*args))
|
future.set_result(callback(*args))
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import math
|
import math
|
||||||
import colorsys
|
import colorsys
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple, List
|
||||||
|
|
||||||
# Official CSS3 colors from w3.org:
|
# Official CSS3 colors from w3.org:
|
||||||
# https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4
|
# https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4
|
||||||
|
@ -162,7 +162,7 @@ COLORS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def color_name_to_rgb(color_name):
|
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
|
||||||
"""Convert color name to RGB hex value."""
|
"""Convert color name to RGB hex value."""
|
||||||
# COLORS map has no spaces in it, so make the color_name have no
|
# COLORS map has no spaces in it, so make the color_name have no
|
||||||
# spaces in it as well for matching purposes
|
# spaces in it as well for matching purposes
|
||||||
|
@ -305,7 +305,8 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
|
||||||
return (r, g, b)
|
return (r, g, b)
|
||||||
|
|
||||||
|
|
||||||
def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]:
|
def color_RGB_to_hsv(
|
||||||
|
iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
|
||||||
"""Convert an rgb color to its hsv representation.
|
"""Convert an rgb color to its hsv representation.
|
||||||
|
|
||||||
Hue is scaled 0-360
|
Hue is scaled 0-360
|
||||||
|
@ -316,7 +317,7 @@ def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]:
|
||||||
return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3)
|
return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3)
|
||||||
|
|
||||||
|
|
||||||
def color_RGB_to_hs(iR: int, iG: int, iB: int) -> Tuple[float, float]:
|
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
|
||||||
"""Convert an rgb color to its hs representation."""
|
"""Convert an rgb color to its hs representation."""
|
||||||
return color_RGB_to_hsv(iR, iG, iB)[:2]
|
return color_RGB_to_hsv(iR, iG, iB)[:2]
|
||||||
|
|
||||||
|
@ -340,7 +341,7 @@ def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
|
||||||
def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]:
|
def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]:
|
||||||
"""Convert an xy color to its hs representation."""
|
"""Convert an xy color to its hs representation."""
|
||||||
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY))
|
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY))
|
||||||
return (h, s)
|
return h, s
|
||||||
|
|
||||||
|
|
||||||
def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
|
def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
|
||||||
|
@ -348,8 +349,7 @@ def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
|
||||||
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS))
|
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS))
|
||||||
|
|
||||||
|
|
||||||
def _match_max_scale(input_colors: Tuple[int, ...],
|
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
|
||||||
output_colors: Tuple[int, ...]) -> Tuple[int, ...]:
|
|
||||||
"""Match the maximum value of the output to the input."""
|
"""Match the maximum value of the output to the input."""
|
||||||
max_in = max(input_colors)
|
max_in = max(input_colors)
|
||||||
max_out = max(output_colors)
|
max_out = max(output_colors)
|
||||||
|
@ -360,7 +360,7 @@ def _match_max_scale(input_colors: Tuple[int, ...],
|
||||||
return tuple(int(round(i * factor)) for i in output_colors)
|
return tuple(int(round(i * factor)) for i in output_colors)
|
||||||
|
|
||||||
|
|
||||||
def color_rgb_to_rgbw(r, g, b):
|
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
|
||||||
"""Convert an rgb color to an rgbw representation."""
|
"""Convert an rgb color to an rgbw representation."""
|
||||||
# Calculate the white channel as the minimum of input rgb channels.
|
# Calculate the white channel as the minimum of input rgb channels.
|
||||||
# Subtract the white portion from the remaining rgb channels.
|
# Subtract the white portion from the remaining rgb channels.
|
||||||
|
@ -369,25 +369,25 @@ def color_rgb_to_rgbw(r, g, b):
|
||||||
|
|
||||||
# Match the output maximum value to the input. This ensures the full
|
# Match the output maximum value to the input. This ensures the full
|
||||||
# channel range is used.
|
# channel range is used.
|
||||||
return _match_max_scale((r, g, b), rgbw)
|
return _match_max_scale((r, g, b), rgbw) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def color_rgbw_to_rgb(r, g, b, w):
|
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
|
||||||
"""Convert an rgbw color to an rgb representation."""
|
"""Convert an rgbw color to an rgb representation."""
|
||||||
# Add the white channel back into the rgb channels.
|
# Add the white channel back into the rgb channels.
|
||||||
rgb = (r + w, g + w, b + w)
|
rgb = (r + w, g + w, b + w)
|
||||||
|
|
||||||
# Match the output maximum value to the input. This ensures the
|
# Match the output maximum value to the input. This ensures the
|
||||||
# output doesn't overflow.
|
# output doesn't overflow.
|
||||||
return _match_max_scale((r, g, b, w), rgb)
|
return _match_max_scale((r, g, b, w), rgb) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def color_rgb_to_hex(r, g, b):
|
def color_rgb_to_hex(r: int, g: int, b: int) -> str:
|
||||||
"""Return a RGB color from a hex color string."""
|
"""Return a RGB color from a hex color string."""
|
||||||
return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b))
|
return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b))
|
||||||
|
|
||||||
|
|
||||||
def rgb_hex_to_rgb_list(hex_string):
|
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
|
||||||
"""Return an RGB color value list from a hex color string."""
|
"""Return an RGB color value list from a hex color string."""
|
||||||
return [int(hex_string[i:i + len(hex_string) // 3], 16)
|
return [int(hex_string[i:i + len(hex_string) // 3], 16)
|
||||||
for i in range(0,
|
for i in range(0,
|
||||||
|
@ -395,12 +395,14 @@ def rgb_hex_to_rgb_list(hex_string):
|
||||||
len(hex_string) // 3)]
|
len(hex_string) // 3)]
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_to_hs(color_temperature_kelvin):
|
def color_temperature_to_hs(
|
||||||
|
color_temperature_kelvin: float) -> Tuple[float, float]:
|
||||||
"""Return an hs color from a color temperature in Kelvin."""
|
"""Return an hs color from a color temperature in Kelvin."""
|
||||||
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
|
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_to_rgb(color_temperature_kelvin):
|
def color_temperature_to_rgb(
|
||||||
|
color_temperature_kelvin: float) -> Tuple[float, float, float]:
|
||||||
"""
|
"""
|
||||||
Return an RGB color from a color temperature in Kelvin.
|
Return an RGB color from a color temperature in Kelvin.
|
||||||
|
|
||||||
|
@ -421,7 +423,7 @@ def color_temperature_to_rgb(color_temperature_kelvin):
|
||||||
|
|
||||||
blue = _get_blue(tmp_internal)
|
blue = _get_blue(tmp_internal)
|
||||||
|
|
||||||
return (red, green, blue)
|
return red, green, blue
|
||||||
|
|
||||||
|
|
||||||
def _bound(color_component: float, minimum: float = 0,
|
def _bound(color_component: float, minimum: float = 0,
|
||||||
|
@ -464,11 +466,11 @@ def _get_blue(temperature: float) -> float:
|
||||||
return _bound(blue)
|
return _bound(blue)
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_mired_to_kelvin(mired_temperature):
|
def color_temperature_mired_to_kelvin(mired_temperature: float) -> float:
|
||||||
"""Convert absolute mired shift to degrees kelvin."""
|
"""Convert absolute mired shift to degrees kelvin."""
|
||||||
return math.floor(1000000 / mired_temperature)
|
return math.floor(1000000 / mired_temperature)
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_kelvin_to_mired(kelvin_temperature):
|
def color_temperature_kelvin_to_mired(kelvin_temperature: float) -> float:
|
||||||
"""Convert degrees kelvin to mired shift."""
|
"""Convert degrees kelvin to mired shift."""
|
||||||
return math.floor(1000000 / kelvin_temperature)
|
return math.floor(1000000 / kelvin_temperature)
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
"""Decorator utility functions."""
|
"""Decorator utility functions."""
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
class Registry(dict):
|
class Registry(dict):
|
||||||
"""Registry of items."""
|
"""Registry of items."""
|
||||||
|
|
||||||
def register(self, name):
|
def register(self, name: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
|
||||||
"""Return decorator to register item with a specific name."""
|
"""Return decorator to register item with a specific name."""
|
||||||
def decorator(func):
|
def decorator(func: CALLABLE_T) -> CALLABLE_T:
|
||||||
"""Register decorated function."""
|
"""Register decorated function."""
|
||||||
self[name] = func
|
self[name] = func
|
||||||
return func
|
return func
|
||||||
|
|
|
@ -71,14 +71,14 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
|
||||||
return dattim.astimezone(UTC)
|
return dattim.astimezone(UTC)
|
||||||
|
|
||||||
|
|
||||||
def as_timestamp(dt_value):
|
def as_timestamp(dt_value: dt.datetime) -> float:
|
||||||
"""Convert a date/time into a unix time (seconds since 1970)."""
|
"""Convert a date/time into a unix time (seconds since 1970)."""
|
||||||
if hasattr(dt_value, "timestamp"):
|
if hasattr(dt_value, "timestamp"):
|
||||||
parsed_dt = dt_value
|
parsed_dt = dt_value # type: Optional[dt.datetime]
|
||||||
else:
|
else:
|
||||||
parsed_dt = parse_datetime(str(dt_value))
|
parsed_dt = parse_datetime(str(dt_value))
|
||||||
if not parsed_dt:
|
if parsed_dt is None:
|
||||||
raise ValueError("not a valid date/time.")
|
raise ValueError("not a valid date/time.")
|
||||||
return parsed_dt.timestamp()
|
return parsed_dt.timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_time(time_str):
|
def parse_time(time_str: str) -> Optional[dt.time]:
|
||||||
"""Parse a time string (00:20:00) into Time object.
|
"""Parse a time string (00:20:00) into Time object.
|
||||||
|
|
||||||
Return None if invalid.
|
Return None if invalid.
|
||||||
|
|
|
@ -38,7 +38,7 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \
|
||||||
return {} if default is None else default
|
return {} if default is None else default
|
||||||
|
|
||||||
|
|
||||||
def save_json(filename: str, data: Union[List, Dict]):
|
def save_json(filename: str, data: Union[List, Dict]) -> None:
|
||||||
"""Save JSON data to a file.
|
"""Save JSON data to a file.
|
||||||
|
|
||||||
Returns True on success.
|
Returns True on success.
|
||||||
|
|
|
@ -33,7 +33,7 @@ LocationInfo = collections.namedtuple(
|
||||||
'use_metric'])
|
'use_metric'])
|
||||||
|
|
||||||
|
|
||||||
def detect_location_info():
|
def detect_location_info() -> Optional[LocationInfo]:
|
||||||
"""Detect location information."""
|
"""Detect location information."""
|
||||||
data = _get_freegeoip()
|
data = _get_freegeoip()
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ def distance(lat1: Optional[float], lon1: Optional[float],
|
||||||
return result * 1000
|
return result * 1000
|
||||||
|
|
||||||
|
|
||||||
def elevation(latitude, longitude):
|
def elevation(latitude: float, longitude: float) -> int:
|
||||||
"""Return elevation for given latitude and longitude."""
|
"""Return elevation for given latitude and longitude."""
|
||||||
try:
|
try:
|
||||||
req = requests.get(
|
req = requests.get(
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Logging utilities."""
|
"""Logging utilities."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from asyncio.events import AbstractEventLoop
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .async_ import run_coroutine_threadsafe
|
from .async_ import run_coroutine_threadsafe
|
||||||
|
|
||||||
|
@ -9,12 +11,12 @@ from .async_ import run_coroutine_threadsafe
|
||||||
class HideSensitiveDataFilter(logging.Filter):
|
class HideSensitiveDataFilter(logging.Filter):
|
||||||
"""Filter API password calls."""
|
"""Filter API password calls."""
|
||||||
|
|
||||||
def __init__(self, text):
|
def __init__(self, text: str) -> None:
|
||||||
"""Initialize sensitive data filter."""
|
"""Initialize sensitive data filter."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
def filter(self, record):
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
"""Hide sensitive data in messages."""
|
"""Hide sensitive data in messages."""
|
||||||
record.msg = record.msg.replace(self.text, '*******')
|
record.msg = record.msg.replace(self.text, '*******')
|
||||||
|
|
||||||
|
@ -25,7 +27,8 @@ class HideSensitiveDataFilter(logging.Filter):
|
||||||
class AsyncHandler:
|
class AsyncHandler:
|
||||||
"""Logging handler wrapper to add an async layer."""
|
"""Logging handler wrapper to add an async layer."""
|
||||||
|
|
||||||
def __init__(self, loop, handler):
|
def __init__(
|
||||||
|
self, loop: AbstractEventLoop, handler: logging.Handler) -> None:
|
||||||
"""Initialize async logging handler wrapper."""
|
"""Initialize async logging handler wrapper."""
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
@ -45,11 +48,11 @@ class AsyncHandler:
|
||||||
|
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Wrap close to handler."""
|
"""Wrap close to handler."""
|
||||||
self.emit(None)
|
self.emit(None)
|
||||||
|
|
||||||
async def async_close(self, blocking=False):
|
async def async_close(self, blocking: bool = False) -> None:
|
||||||
"""Close the handler.
|
"""Close the handler.
|
||||||
|
|
||||||
When blocking=True, will wait till closed.
|
When blocking=True, will wait till closed.
|
||||||
|
@ -60,7 +63,7 @@ class AsyncHandler:
|
||||||
while self._thread.is_alive():
|
while self._thread.is_alive():
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0, loop=self.loop)
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record: Optional[logging.LogRecord]) -> None:
|
||||||
"""Process a record."""
|
"""Process a record."""
|
||||||
ident = self.loop.__dict__.get("_thread_ident")
|
ident = self.loop.__dict__.get("_thread_ident")
|
||||||
|
|
||||||
|
@ -71,11 +74,11 @@ class AsyncHandler:
|
||||||
else:
|
else:
|
||||||
self.loop.call_soon_threadsafe(self._queue.put_nowait, record)
|
self.loop.call_soon_threadsafe(self._queue.put_nowait, record)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the string names."""
|
"""Return the string names."""
|
||||||
return str(self.handler)
|
return str(self.handler)
|
||||||
|
|
||||||
def _process(self):
|
def _process(self) -> None:
|
||||||
"""Process log in a thread."""
|
"""Process log in a thread."""
|
||||||
while True:
|
while True:
|
||||||
record = run_coroutine_threadsafe(
|
record = run_coroutine_threadsafe(
|
||||||
|
@ -87,34 +90,34 @@ class AsyncHandler:
|
||||||
|
|
||||||
self.handler.emit(record)
|
self.handler.emit(record)
|
||||||
|
|
||||||
def createLock(self):
|
def createLock(self) -> None:
|
||||||
"""Ignore lock stuff."""
|
"""Ignore lock stuff."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def acquire(self):
|
def acquire(self) -> None:
|
||||||
"""Ignore lock stuff."""
|
"""Ignore lock stuff."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def release(self):
|
def release(self) -> None:
|
||||||
"""Ignore lock stuff."""
|
"""Ignore lock stuff."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def level(self):
|
def level(self) -> int:
|
||||||
"""Wrap property level to handler."""
|
"""Wrap property level to handler."""
|
||||||
return self.handler.level
|
return self.handler.level
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def formatter(self):
|
def formatter(self) -> Optional[logging.Formatter]:
|
||||||
"""Wrap property formatter to handler."""
|
"""Wrap property formatter to handler."""
|
||||||
return self.handler.formatter
|
return self.handler.formatter
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
"""Wrap property set_name to handler."""
|
"""Wrap property set_name to handler."""
|
||||||
return self.handler.get_name()
|
return self.handler.get_name() # type: ignore
|
||||||
|
|
||||||
@name.setter
|
@name.setter
|
||||||
def name(self, name):
|
def name(self, name: str) -> None:
|
||||||
"""Wrap property get_name to handler."""
|
"""Wrap property get_name to handler."""
|
||||||
self.handler.name = name
|
self.handler.set_name(name) # type: ignore
|
||||||
|
|
|
@ -16,7 +16,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
INSTALL_LOCK = threading.Lock()
|
INSTALL_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def is_virtual_env():
|
def is_virtual_env() -> bool:
|
||||||
"""Return if we run in a virtual environtment."""
|
"""Return if we run in a virtual environtment."""
|
||||||
# Check supports venv && virtualenv
|
# Check supports venv && virtualenv
|
||||||
return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or
|
return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or
|
||||||
|
|
|
@ -4,7 +4,7 @@ import ssl
|
||||||
import certifi
|
import certifi
|
||||||
|
|
||||||
|
|
||||||
def client_context():
|
def client_context() -> ssl.SSLContext:
|
||||||
"""Return an SSL context for making requests."""
|
"""Return an SSL context for making requests."""
|
||||||
context = ssl.create_default_context(
|
context = ssl.create_default_context(
|
||||||
purpose=ssl.Purpose.SERVER_AUTH,
|
purpose=ssl.Purpose.SERVER_AUTH,
|
||||||
|
@ -13,7 +13,7 @@ def client_context():
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def server_context():
|
def server_context() -> ssl.SSLContext:
|
||||||
"""Return an SSL context following the Mozilla recommendations.
|
"""Return an SSL context following the Mozilla recommendations.
|
||||||
|
|
||||||
TLS configuration follows the best-practice guidelines specified here:
|
TLS configuration follows the best-practice guidelines specified here:
|
||||||
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import fnmatch
|
import fnmatch
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Union, List, Dict
|
from typing import Union, List, Dict, Iterator, overload, TypeVar
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
try:
|
try:
|
||||||
|
@ -22,7 +22,10 @@ from homeassistant.exceptions import HomeAssistantError
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
_SECRET_NAMESPACE = 'homeassistant'
|
_SECRET_NAMESPACE = 'homeassistant'
|
||||||
SECRET_YAML = 'secrets.yaml'
|
SECRET_YAML = 'secrets.yaml'
|
||||||
__SECRET_CACHE = {} # type: Dict
|
__SECRET_CACHE = {} # type: Dict[str, JSON_TYPE]
|
||||||
|
|
||||||
|
JSON_TYPE = Union[List, Dict, str]
|
||||||
|
DICT_T = TypeVar('DICT_T', bound=Dict)
|
||||||
|
|
||||||
|
|
||||||
class NodeListClass(list):
|
class NodeListClass(list):
|
||||||
|
@ -37,7 +40,42 @@ class NodeStrClass(str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _add_reference(obj, loader, node):
|
# pylint: disable=too-many-ancestors
|
||||||
|
class SafeLineLoader(yaml.SafeLoader):
|
||||||
|
"""Loader class that keeps track of line numbers."""
|
||||||
|
|
||||||
|
def compose_node(self, parent: yaml.nodes.Node,
|
||||||
|
index: int) -> yaml.nodes.Node:
|
||||||
|
"""Annotate a node with the first line it was seen."""
|
||||||
|
last_line = self.line # type: int
|
||||||
|
node = super(SafeLineLoader,
|
||||||
|
self).compose_node(parent, index) # type: yaml.nodes.Node
|
||||||
|
node.__line__ = last_line + 1 # type: ignore
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=pointless-statement
|
||||||
|
@overload
|
||||||
|
def _add_reference(obj: Union[list, NodeListClass],
|
||||||
|
loader: yaml.SafeLoader,
|
||||||
|
node: yaml.nodes.Node) -> NodeListClass: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload # noqa: F811
|
||||||
|
def _add_reference(obj: Union[str, NodeStrClass],
|
||||||
|
loader: yaml.SafeLoader,
|
||||||
|
node: yaml.nodes.Node) -> NodeStrClass: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload # noqa: F811
|
||||||
|
def _add_reference(obj: DICT_T,
|
||||||
|
loader: yaml.SafeLoader,
|
||||||
|
node: yaml.nodes.Node) -> DICT_T: ...
|
||||||
|
# pylint: enable=pointless-statement
|
||||||
|
|
||||||
|
|
||||||
|
def _add_reference(obj, loader: SafeLineLoader, # type: ignore # noqa: F811
|
||||||
|
node: yaml.nodes.Node):
|
||||||
"""Add file reference information to an object."""
|
"""Add file reference information to an object."""
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
obj = NodeListClass(obj)
|
obj = NodeListClass(obj)
|
||||||
|
@ -48,20 +86,7 @@ def _add_reference(obj, loader, node):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-ancestors
|
def load_yaml(fname: str) -> JSON_TYPE:
|
||||||
class SafeLineLoader(yaml.SafeLoader):
|
|
||||||
"""Loader class that keeps track of line numbers."""
|
|
||||||
|
|
||||||
def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node:
|
|
||||||
"""Annotate a node with the first line it was seen."""
|
|
||||||
last_line = self.line # type: int
|
|
||||||
node = super(SafeLineLoader,
|
|
||||||
self).compose_node(parent, index) # type: yaml.nodes.Node
|
|
||||||
node.__line__ = last_line + 1 # type: ignore
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(fname: str) -> Union[List, Dict]:
|
|
||||||
"""Load a YAML file."""
|
"""Load a YAML file."""
|
||||||
try:
|
try:
|
||||||
with open(fname, encoding='utf-8') as conf_file:
|
with open(fname, encoding='utf-8') as conf_file:
|
||||||
|
@ -83,12 +108,12 @@ def dump(_dict: dict) -> str:
|
||||||
.replace(': null\n', ':\n')
|
.replace(': null\n', ':\n')
|
||||||
|
|
||||||
|
|
||||||
def save_yaml(path, data):
|
def save_yaml(path: str, data: dict) -> None:
|
||||||
"""Save YAML to a file."""
|
"""Save YAML to a file."""
|
||||||
# Dump before writing to not truncate the file if dumping fails
|
# Dump before writing to not truncate the file if dumping fails
|
||||||
data = dump(data)
|
str_data = dump(data)
|
||||||
with open(path, 'w', encoding='utf-8') as outfile:
|
with open(path, 'w', encoding='utf-8') as outfile:
|
||||||
outfile.write(data)
|
outfile.write(str_data)
|
||||||
|
|
||||||
|
|
||||||
def clear_secret_cache() -> None:
|
def clear_secret_cache() -> None:
|
||||||
|
@ -100,7 +125,7 @@ def clear_secret_cache() -> None:
|
||||||
|
|
||||||
|
|
||||||
def _include_yaml(loader: SafeLineLoader,
|
def _include_yaml(loader: SafeLineLoader,
|
||||||
node: yaml.nodes.Node) -> Union[List, Dict]:
|
node: yaml.nodes.Node) -> JSON_TYPE:
|
||||||
"""Load another YAML file and embeds it using the !include tag.
|
"""Load another YAML file and embeds it using the !include tag.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -115,7 +140,7 @@ def _is_file_valid(name: str) -> bool:
|
||||||
return not name.startswith('.')
|
return not name.startswith('.')
|
||||||
|
|
||||||
|
|
||||||
def _find_files(directory: str, pattern: str):
|
def _find_files(directory: str, pattern: str) -> Iterator[str]:
|
||||||
"""Recursively load files in a directory."""
|
"""Recursively load files in a directory."""
|
||||||
for root, dirs, files in os.walk(directory, topdown=True):
|
for root, dirs, files in os.walk(directory, topdown=True):
|
||||||
dirs[:] = [d for d in dirs if _is_file_valid(d)]
|
dirs[:] = [d for d in dirs if _is_file_valid(d)]
|
||||||
|
@ -151,7 +176,7 @@ def _include_dir_merge_named_yaml(loader: SafeLineLoader,
|
||||||
|
|
||||||
|
|
||||||
def _include_dir_list_yaml(loader: SafeLineLoader,
|
def _include_dir_list_yaml(loader: SafeLineLoader,
|
||||||
node: yaml.nodes.Node):
|
node: yaml.nodes.Node) -> List[JSON_TYPE]:
|
||||||
"""Load multiple files from directory as a list."""
|
"""Load multiple files from directory as a list."""
|
||||||
loc = os.path.join(os.path.dirname(loader.name), node.value)
|
loc = os.path.join(os.path.dirname(loader.name), node.value)
|
||||||
return [load_yaml(f) for f in _find_files(loc, '*.yaml')
|
return [load_yaml(f) for f in _find_files(loc, '*.yaml')
|
||||||
|
@ -159,11 +184,11 @@ def _include_dir_list_yaml(loader: SafeLineLoader,
|
||||||
|
|
||||||
|
|
||||||
def _include_dir_merge_list_yaml(loader: SafeLineLoader,
|
def _include_dir_merge_list_yaml(loader: SafeLineLoader,
|
||||||
node: yaml.nodes.Node):
|
node: yaml.nodes.Node) -> JSON_TYPE:
|
||||||
"""Load multiple files from directory as a merged list."""
|
"""Load multiple files from directory as a merged list."""
|
||||||
loc = os.path.join(os.path.dirname(loader.name),
|
loc = os.path.join(os.path.dirname(loader.name),
|
||||||
node.value) # type: str
|
node.value) # type: str
|
||||||
merged_list = [] # type: List
|
merged_list = [] # type: List[JSON_TYPE]
|
||||||
for fname in _find_files(loc, '*.yaml'):
|
for fname in _find_files(loc, '*.yaml'):
|
||||||
if os.path.basename(fname) == SECRET_YAML:
|
if os.path.basename(fname) == SECRET_YAML:
|
||||||
continue
|
continue
|
||||||
|
@ -202,14 +227,14 @@ def _ordered_dict(loader: SafeLineLoader,
|
||||||
return _add_reference(OrderedDict(nodes), loader, node)
|
return _add_reference(OrderedDict(nodes), loader, node)
|
||||||
|
|
||||||
|
|
||||||
def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node):
|
def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node) -> JSON_TYPE:
|
||||||
"""Add line number and file name to Load YAML sequence."""
|
"""Add line number and file name to Load YAML sequence."""
|
||||||
obj, = loader.construct_yaml_seq(node)
|
obj, = loader.construct_yaml_seq(node)
|
||||||
return _add_reference(obj, loader, node)
|
return _add_reference(obj, loader, node)
|
||||||
|
|
||||||
|
|
||||||
def _env_var_yaml(loader: SafeLineLoader,
|
def _env_var_yaml(loader: SafeLineLoader,
|
||||||
node: yaml.nodes.Node):
|
node: yaml.nodes.Node) -> str:
|
||||||
"""Load environment variables and embed it into the configuration YAML."""
|
"""Load environment variables and embed it into the configuration YAML."""
|
||||||
args = node.value.split()
|
args = node.value.split()
|
||||||
|
|
||||||
|
@ -222,7 +247,7 @@ def _env_var_yaml(loader: SafeLineLoader,
|
||||||
raise HomeAssistantError(node.value)
|
raise HomeAssistantError(node.value)
|
||||||
|
|
||||||
|
|
||||||
def _load_secret_yaml(secret_path: str) -> Dict:
|
def _load_secret_yaml(secret_path: str) -> JSON_TYPE:
|
||||||
"""Load the secrets yaml from path."""
|
"""Load the secrets yaml from path."""
|
||||||
secret_path = os.path.join(secret_path, SECRET_YAML)
|
secret_path = os.path.join(secret_path, SECRET_YAML)
|
||||||
if secret_path in __SECRET_CACHE:
|
if secret_path in __SECRET_CACHE:
|
||||||
|
@ -248,7 +273,7 @@ def _load_secret_yaml(secret_path: str) -> Dict:
|
||||||
|
|
||||||
|
|
||||||
def _secret_yaml(loader: SafeLineLoader,
|
def _secret_yaml(loader: SafeLineLoader,
|
||||||
node: yaml.nodes.Node):
|
node: yaml.nodes.Node) -> JSON_TYPE:
|
||||||
"""Load secrets and embed it into the configuration YAML."""
|
"""Load secrets and embed it into the configuration YAML."""
|
||||||
secret_path = os.path.dirname(loader.name)
|
secret_path = os.path.dirname(loader.name)
|
||||||
while True:
|
while True:
|
||||||
|
@ -308,7 +333,8 @@ yaml.SafeLoader.add_constructor('!include_dir_merge_named',
|
||||||
|
|
||||||
# From: https://gist.github.com/miracle2k/3184458
|
# From: https://gist.github.com/miracle2k/3184458
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
def represent_odict(dump, tag, mapping, flow_style=None):
|
def represent_odict(dump, tag, mapping, # type: ignore
|
||||||
|
flow_style=None) -> yaml.MappingNode:
|
||||||
"""Like BaseRepresenter.represent_mapping but does not issue the sort()."""
|
"""Like BaseRepresenter.represent_mapping but does not issue the sort()."""
|
||||||
value = [] # type: list
|
value = [] # type: list
|
||||||
node = yaml.MappingNode(tag, value, flow_style=flow_style)
|
node = yaml.MappingNode(tag, value, flow_style=flow_style)
|
||||||
|
|
7
mypy.ini
7
mypy.ini
|
@ -2,11 +2,18 @@
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
follow_imports = silent
|
follow_imports = silent
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
warn_incomplete_stub = true
|
||||||
warn_redundant_casts = true
|
warn_redundant_casts = true
|
||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unused_configs = true
|
warn_unused_configs = true
|
||||||
warn_unused_ignores = true
|
warn_unused_ignores = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.*]
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.config_entries]
|
||||||
|
disallow_untyped_defs = false
|
||||||
|
|
||||||
[mypy-homeassistant.util.yaml]
|
[mypy-homeassistant.util.yaml]
|
||||||
warn_return_any = false
|
warn_return_any = false
|
||||||
|
|
||||||
|
|
|
@ -437,10 +437,12 @@ class TestAutomation(unittest.TestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}}):
|
}}):
|
||||||
automation.reload(self.hass)
|
with patch('homeassistant.config.find_config_file',
|
||||||
self.hass.block_till_done()
|
return_value=''):
|
||||||
# De-flake ?!
|
automation.reload(self.hass)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
# De-flake ?!
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert self.hass.states.get('automation.hello') is None
|
assert self.hass.states.get('automation.hello') is None
|
||||||
assert self.hass.states.get('automation.bye') is not None
|
assert self.hass.states.get('automation.bye') is not None
|
||||||
|
@ -485,8 +487,10 @@ class TestAutomation(unittest.TestCase):
|
||||||
|
|
||||||
with patch('homeassistant.config.load_yaml_config_file', autospec=True,
|
with patch('homeassistant.config.load_yaml_config_file', autospec=True,
|
||||||
return_value={automation.DOMAIN: 'not valid'}):
|
return_value={automation.DOMAIN: 'not valid'}):
|
||||||
automation.reload(self.hass)
|
with patch('homeassistant.config.find_config_file',
|
||||||
self.hass.block_till_done()
|
return_value=''):
|
||||||
|
automation.reload(self.hass)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert self.hass.states.get('automation.hello') is None
|
assert self.hass.states.get('automation.hello') is None
|
||||||
|
|
||||||
|
@ -521,8 +525,10 @@ class TestAutomation(unittest.TestCase):
|
||||||
|
|
||||||
with patch('homeassistant.config.load_yaml_config_file',
|
with patch('homeassistant.config.load_yaml_config_file',
|
||||||
side_effect=HomeAssistantError('bla')):
|
side_effect=HomeAssistantError('bla')):
|
||||||
automation.reload(self.hass)
|
with patch('homeassistant.config.find_config_file',
|
||||||
self.hass.block_till_done()
|
return_value=''):
|
||||||
|
automation.reload(self.hass)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert self.hass.states.get('automation.hello') is not None
|
assert self.hass.states.get('automation.hello') is not None
|
||||||
|
|
||||||
|
|
|
@ -365,8 +365,10 @@ class TestComponentsGroup(unittest.TestCase):
|
||||||
'icon': 'mdi:work',
|
'icon': 'mdi:work',
|
||||||
'view': True,
|
'view': True,
|
||||||
}}}):
|
}}}):
|
||||||
group.reload(self.hass)
|
with patch('homeassistant.config.find_config_file',
|
||||||
self.hass.block_till_done()
|
return_value=''):
|
||||||
|
group.reload(self.hass)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert sorted(self.hass.states.entity_ids()) == \
|
assert sorted(self.hass.states.entity_ids()) == \
|
||||||
['group.all_tests', 'group.hello']
|
['group.all_tests', 'group.hello']
|
||||||
|
|
|
@ -199,8 +199,10 @@ class TestScriptComponent(unittest.TestCase):
|
||||||
}
|
}
|
||||||
}]
|
}]
|
||||||
}}}):
|
}}}):
|
||||||
script.reload(self.hass)
|
with patch('homeassistant.config.find_config_file',
|
||||||
self.hass.block_till_done()
|
return_value=''):
|
||||||
|
script.reload(self.hass)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert self.hass.states.get(ENTITY_ID) is None
|
assert self.hass.states.get(ENTITY_ID) is None
|
||||||
assert not self.hass.services.has_service(script.DOMAIN, 'test')
|
assert not self.hass.services.has_service(script.DOMAIN, 'test')
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue