Type hint additions (#26765)

This commit is contained in:
Ville Skyttä 2019-09-20 18:23:34 +03:00 committed by Paulus Schoutsen
parent 6a3132344c
commit 54242cd65c
10 changed files with 60 additions and 52 deletions

View file

@ -3,6 +3,7 @@ import asyncio
from functools import partial from functools import partial
import importlib import importlib
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -34,7 +35,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util.dt import parse_datetime, utcnow from homeassistant.util.dt import parse_datetime, utcnow
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any # mypy: no-check-untyped-defs, no-warn-return-any
DOMAIN = "automation" DOMAIN = "automation"
@ -281,11 +282,11 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
if enable_automation: if enable_automation:
await self.async_enable() await self.async_enable()
async def async_turn_on(self, **kwargs) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the entity on and update the state.""" """Turn the entity on and update the state."""
await self.async_enable() await self.async_enable()
async def async_turn_off(self, **kwargs) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the entity off.""" """Turn the entity off."""
await self.async_disable() await self.async_disable()

View file

@ -2,6 +2,7 @@
from datetime import timedelta from datetime import timedelta
import functools as ft import functools as ft
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -33,7 +34,7 @@ from homeassistant.const import (
) )
# mypy: allow-untyped-calls, allow-incomplete-defs, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -263,7 +264,7 @@ class CoverDevice(Entity):
"""Return if the cover is closed or not.""" """Return if the cover is closed or not."""
raise NotImplementedError() raise NotImplementedError()
def open_cover(self, **kwargs): def open_cover(self, **kwargs: Any) -> None:
"""Open the cover.""" """Open the cover."""
raise NotImplementedError() raise NotImplementedError()
@ -274,7 +275,7 @@ class CoverDevice(Entity):
""" """
return self.hass.async_add_job(ft.partial(self.open_cover, **kwargs)) return self.hass.async_add_job(ft.partial(self.open_cover, **kwargs))
def close_cover(self, **kwargs): def close_cover(self, **kwargs: Any) -> None:
"""Close cover.""" """Close cover."""
raise NotImplementedError() raise NotImplementedError()
@ -285,7 +286,7 @@ class CoverDevice(Entity):
""" """
return self.hass.async_add_job(ft.partial(self.close_cover, **kwargs)) return self.hass.async_add_job(ft.partial(self.close_cover, **kwargs))
def toggle(self, **kwargs) -> None: def toggle(self, **kwargs: Any) -> None:
"""Toggle the entity.""" """Toggle the entity."""
if self.is_closed: if self.is_closed:
self.open_cover(**kwargs) self.open_cover(**kwargs)
@ -323,7 +324,7 @@ class CoverDevice(Entity):
""" """
return self.hass.async_add_job(ft.partial(self.stop_cover, **kwargs)) return self.hass.async_add_job(ft.partial(self.stop_cover, **kwargs))
def open_cover_tilt(self, **kwargs): def open_cover_tilt(self, **kwargs: Any) -> None:
"""Open the cover tilt.""" """Open the cover tilt."""
pass pass
@ -334,7 +335,7 @@ class CoverDevice(Entity):
""" """
return self.hass.async_add_job(ft.partial(self.open_cover_tilt, **kwargs)) return self.hass.async_add_job(ft.partial(self.open_cover_tilt, **kwargs))
def close_cover_tilt(self, **kwargs): def close_cover_tilt(self, **kwargs: Any) -> None:
"""Close the cover tilt.""" """Close the cover tilt."""
pass pass
@ -369,7 +370,7 @@ class CoverDevice(Entity):
""" """
return self.hass.async_add_job(ft.partial(self.stop_cover_tilt, **kwargs)) return self.hass.async_add_job(ft.partial(self.stop_cover_tilt, **kwargs))
def toggle_tilt(self, **kwargs) -> None: def toggle_tilt(self, **kwargs: Any) -> None:
"""Toggle the entity.""" """Toggle the entity."""
if self.current_cover_tilt_position == 0: if self.current_cover_tilt_position == 0:
self.open_cover_tilt(**kwargs) self.open_cover_tilt(**kwargs)

View file

@ -4,6 +4,7 @@ import logging
import mimetypes import mimetypes
import os import os
import pathlib import pathlib
from typing import Optional, Set, Tuple
from aiohttp import web, web_urldispatcher, hdrs from aiohttp import web, web_urldispatcher, hdrs
import voluptuous as vol import voluptuous as vol
@ -22,7 +23,7 @@ from homeassistant.loader import bind_hass
from .storage import async_setup_frontend_storage from .storage import async_setup_frontend_storage
# mypy: allow-incomplete-defs, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
# Fix mimetypes for borked Windows machines # Fix mimetypes for borked Windows machines
# https://github.com/home-assistant/home-assistant-polymer/issues/3336 # https://github.com/home-assistant/home-assistant-polymer/issues/3336
@ -400,7 +401,9 @@ class IndexView(web_urldispatcher.AbstractResource):
"""Construct url for resource with additional params.""" """Construct url for resource with additional params."""
return URL("/") return URL("/")
async def resolve(self, request: web.Request): async def resolve(
self, request: web.Request
) -> Tuple[Optional[web_urldispatcher.UrlMappingMatchInfo], Set[str]]:
"""Resolve resource. """Resolve resource.
Return (UrlMappingMatchInfo, allowed_methods) pair. Return (UrlMappingMatchInfo, allowed_methods) pair.
@ -447,7 +450,7 @@ class IndexView(web_urldispatcher.AbstractResource):
return tpl return tpl
async def get(self, request: web.Request): async def get(self, request: web.Request) -> web.Response:
"""Serve the index page for panel pages.""" """Serve the index page for panel pages."""
hass = request.app["hass"] hass = request.app["hass"]

View file

@ -18,7 +18,7 @@ from homeassistant.util.yaml import dump
from .const import KEY_REAL_IP from .const import KEY_REAL_IP
# mypy: allow-incomplete-defs, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -165,7 +165,7 @@ class IpBan:
self.banned_at = banned_at or datetime.utcnow() self.banned_at = banned_at or datetime.utcnow()
async def async_load_ip_bans_config(hass: HomeAssistant, path: str): async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> List[IpBan]:
"""Load list of banned IPs from config file.""" """Load list of banned IPs from config file."""
ip_list: List[IpBan] = [] ip_list: List[IpBan] = []
@ -188,7 +188,7 @@ async def async_load_ip_bans_config(hass: HomeAssistant, path: str):
return ip_list return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan): def update_ip_bans_config(path: str, ip_ban: IpBan) -> None:
"""Update config file with new banned IP address.""" """Update config file with new banned IP address."""
with open(path, "a") as out: with open(path, "a") as out:
ip_ = { ip_ = {

View file

@ -36,7 +36,7 @@ from .const import (
) )
# mypy: allow-incomplete-defs, allow-untyped-defs # mypy: allow-untyped-defs
async def _async_reproduce_states( async def _async_reproduce_states(
@ -44,7 +44,7 @@ async def _async_reproduce_states(
) -> None: ) -> None:
"""Reproduce component states.""" """Reproduce component states."""
async def call_service(service: str, keys: Iterable): async def call_service(service: str, keys: Iterable) -> None:
"""Call service with set of attributes given.""" """Call service with set of attributes given."""
data = {} data = {}
data["entity_id"] = state.entity_id data["entity_id"] = state.entity_id

View file

@ -1,6 +1,6 @@
"""Light support for switch entities.""" """Light support for switch entities."""
import logging import logging
from typing import cast from typing import cast, Callable, Dict, Optional, Sequence
import voluptuous as vol import voluptuous as vol
@ -14,13 +14,14 @@ from homeassistant.const import (
) )
from homeassistant.core import State, callback from homeassistant.core import State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.components.light import PLATFORM_SCHEMA, Light from homeassistant.components.light import PLATFORM_SCHEMA, Light
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -35,7 +36,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
async def async_setup_platform( async def async_setup_platform(
hass: HomeAssistantType, config: ConfigType, async_add_entities, discovery_info=None hass: HomeAssistantType,
config: ConfigType,
async_add_entities: Callable[[Sequence[Entity], bool], None],
discovery_info: Optional[Dict] = None,
) -> None: ) -> None:
"""Initialize Light Switch platform.""" """Initialize Light Switch platform."""
async_add_entities( async_add_entities(
@ -105,7 +109,7 @@ class LightSwitch(Light):
@callback @callback
def async_state_changed_listener( def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State entity_id: str, old_state: State, new_state: State
): ) -> None:
"""Handle child updates.""" """Handle child updates."""
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)

View file

@ -52,7 +52,7 @@ from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import slugify as util_slugify from homeassistant.util import slugify as util_slugify
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any # mypy: no-check-untyped-defs, no-warn-return-any
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -95,7 +95,7 @@ def has_at_least_one_key(*keys: str) -> Callable:
return validate return validate
def has_at_most_one_key(*keys: str) -> Callable: def has_at_most_one_key(*keys: str) -> Callable[[Dict], Dict]:
"""Validate that zero keys exist or one key exists.""" """Validate that zero keys exist or one key exists."""
def validate(obj: Dict) -> Dict: def validate(obj: Dict) -> Dict:
@ -224,7 +224,7 @@ def entity_ids(value: Union[str, List]) -> List[str]:
comp_entity_ids = vol.Any(vol.All(vol.Lower, ENTITY_MATCH_ALL), entity_ids) comp_entity_ids = vol.Any(vol.All(vol.Lower, ENTITY_MATCH_ALL), entity_ids)
def entity_domain(domain: str): def entity_domain(domain: str) -> Callable[[Any], str]:
"""Validate that entity belong to domain.""" """Validate that entity belong to domain."""
def validate(value: Any) -> str: def validate(value: Any) -> str:
@ -235,7 +235,7 @@ def entity_domain(domain: str):
return validate return validate
def entities_domain(domain: str): def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]:
"""Validate that entities belong to domain.""" """Validate that entities belong to domain."""
def validate(values: Union[str, List]) -> List[str]: def validate(values: Union[str, List]) -> List[str]:
@ -284,7 +284,7 @@ time_period_dict = vol.All(
) )
def time(value) -> time_sys: def time(value: Any) -> time_sys:
"""Validate and transform a time.""" """Validate and transform a time."""
if isinstance(value, time_sys): if isinstance(value, time_sys):
return value return value
@ -300,7 +300,7 @@ def time(value) -> time_sys:
return time_val return time_val
def date(value) -> date_sys: def date(value: Any) -> date_sys:
"""Validate and transform a date.""" """Validate and transform a date."""
if isinstance(value, date_sys): if isinstance(value, date_sys):
return value return value
@ -439,7 +439,7 @@ def string(value: Any) -> str:
return str(value) return str(value)
def temperature_unit(value) -> str: def temperature_unit(value: Any) -> str:
"""Validate and transform temperature unit.""" """Validate and transform temperature unit."""
value = str(value).upper() value = str(value).upper()
if value == "C": if value == "C":
@ -578,7 +578,7 @@ def deprecated(
replacement_key: Optional[str] = None, replacement_key: Optional[str] = None,
invalidation_version: Optional[str] = None, invalidation_version: Optional[str] = None,
default: Optional[Any] = None, default: Optional[Any] = None,
): ) -> Callable[[Dict], Dict]:
""" """
Log key as deprecated and provide a replacement (if exists). Log key as deprecated and provide a replacement (if exists).
@ -626,7 +626,7 @@ def deprecated(
" deprecated, please remove it from your configuration" " deprecated, please remove it from your configuration"
) )
def check_for_invalid_version(value: Optional[Any]): def check_for_invalid_version(value: Optional[Any]) -> None:
"""Raise error if current version has reached invalidation.""" """Raise error if current version has reached invalidation."""
if not invalidation_version: if not invalidation_version:
return return
@ -641,7 +641,7 @@ def deprecated(
) )
) )
def validator(config: Dict): def validator(config: Dict) -> Dict:
"""Check if key is in config and log warning.""" """Check if key is in config and log warning."""
if key in config: if key in config:
value = config[key] value = config[key]

View file

@ -4,7 +4,7 @@ import logging
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from itertools import islice from itertools import islice
from typing import Optional, Sequence, Callable, Dict, List, Set, Tuple from typing import Optional, Sequence, Callable, Dict, List, Set, Tuple, Any
import voluptuous as vol import voluptuous as vol
@ -32,8 +32,7 @@ import homeassistant.util.dt as date_util
from homeassistant.util.async_ import run_coroutine_threadsafe, run_callback_threadsafe from homeassistant.util.async_ import run_coroutine_threadsafe, run_callback_threadsafe
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# mypy: no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -101,9 +100,9 @@ class Script:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
sequence, sequence: Sequence[Dict[str, Any]],
name: Optional[str] = None, name: Optional[str] = None,
change_listener=None, change_listener: Optional[Callable[..., Any]] = None,
) -> None: ) -> None:
"""Initialize the script.""" """Initialize the script."""
self.hass = hass self.hass = hass

View file

@ -7,7 +7,7 @@ import random
import re import re
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Iterable from typing import Any, Iterable
import jinja2 import jinja2
from jinja2 import contextfilter, contextfunction from jinja2 import contextfilter, contextfunction
@ -25,13 +25,13 @@ from homeassistant.const import (
from homeassistant.core import State, callback, split_entity_id, valid_entity_id from homeassistant.core import State, callback, split_entity_id, valid_entity_id
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper from homeassistant.helpers import location as loc_helper
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import convert, dt as dt_util, location as loc_util from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any # mypy: no-check-untyped-defs, no-warn-return-any
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -106,7 +106,7 @@ def extract_entities(template, variables=None):
return MATCH_ALL return MATCH_ALL
def _true(arg) -> bool: def _true(arg: Any) -> bool:
return True return True
@ -191,7 +191,7 @@ class Template:
"""Extract all entities for state_changed listener.""" """Extract all entities for state_changed listener."""
return extract_entities(self.template, variables) return extract_entities(self.template, variables)
def render(self, variables: TemplateVarsType = None, **kwargs): def render(self, variables: TemplateVarsType = None, **kwargs: Any) -> str:
"""Render given template.""" """Render given template."""
if variables is not None: if variables is not None:
kwargs.update(variables) kwargs.update(variables)
@ -201,7 +201,7 @@ class Template:
).result() ).result()
@callback @callback
def async_render(self, variables: TemplateVarsType = None, **kwargs) -> str: def async_render(self, variables: TemplateVarsType = None, **kwargs: Any) -> str:
"""Render given template. """Render given template.
This method must be run in the event loop. This method must be run in the event loop.
@ -218,7 +218,7 @@ class Template:
@callback @callback
def async_render_to_info( def async_render_to_info(
self, variables: TemplateVarsType = None, **kwargs self, variables: TemplateVarsType = None, **kwargs: Any
) -> RenderInfo: ) -> RenderInfo:
"""Render the template and collect an entity filter.""" """Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data assert self.hass and _RENDER_INFO not in self.hass.data
@ -479,7 +479,7 @@ def _resolve_state(hass, entity_id_or_state):
return None return None
def expand(hass, *args) -> Iterable[State]: def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]:
"""Expand out any groups into entity states.""" """Expand out any groups into entity states."""
search = list(args) search = list(args)
found = {} found = {}
@ -635,7 +635,7 @@ def distance(hass, *args):
) )
def is_state(hass, entity_id: str, state: State) -> bool: def is_state(hass: HomeAssistantType, entity_id: str, state: State) -> bool:
"""Test if a state is a specific value.""" """Test if a state is a specific value."""
state_obj = _get_state(hass, entity_id) state_obj = _get_state(hass, entity_id)
return state_obj is not None and state_obj.state == state return state_obj is not None and state_obj.state == state

View file

@ -5,7 +5,7 @@ import importlib
import logging import logging
import os import os
import sys import sys
from typing import List from typing import List, Optional, Sequence, Text
from homeassistant.bootstrap import async_mount_local_lib_path from homeassistant.bootstrap import async_mount_local_lib_path
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
@ -13,7 +13,7 @@ from homeassistant.requirements import pip_kwargs
from homeassistant.util.package import install_package, is_virtual_env, is_installed from homeassistant.util.package import install_package, is_virtual_env, is_installed
# mypy: allow-untyped-defs, allow-incomplete-defs, no-warn-return-any # mypy: allow-untyped-defs, no-warn-return-any
def run(args: List) -> int: def run(args: List) -> int:
@ -62,13 +62,13 @@ def run(args: List) -> int:
return script.run(args[1:]) # type: ignore return script.run(args[1:]) # type: ignore
def extract_config_dir(args=None) -> str: def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str:
"""Extract the config dir from the arguments or get the default.""" """Extract the config dir from the arguments or get the default."""
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("-c", "--config", default=None) parser.add_argument("-c", "--config", default=None)
args = parser.parse_known_args(args)[0] parsed_args = parser.parse_known_args(args)[0]
return ( return (
os.path.join(os.getcwd(), args.config) os.path.join(os.getcwd(), parsed_args.config)
if args.config if parsed_args.config
else get_default_config_dir() else get_default_config_dir()
) )