From 54242cd65c1d69d88ec366f545313cf003e9cbbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Fri, 20 Sep 2019 18:23:34 +0300 Subject: [PATCH] Type hint additions (#26765) --- .../components/automation/__init__.py | 7 ++++--- homeassistant/components/cover/__init__.py | 15 +++++++------- homeassistant/components/frontend/__init__.py | 9 ++++++--- homeassistant/components/http/ban.py | 6 +++--- .../media_player/reproduce_state.py | 4 ++-- homeassistant/components/switch/light.py | 12 +++++++---- homeassistant/helpers/config_validation.py | 20 +++++++++---------- homeassistant/helpers/script.py | 9 ++++----- homeassistant/helpers/template.py | 18 ++++++++--------- homeassistant/scripts/__init__.py | 12 +++++------ 10 files changed, 60 insertions(+), 52 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 9e08a9cff1f..f0529f126f1 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -3,6 +3,7 @@ import asyncio from functools import partial import importlib import logging +from typing import Any import voluptuous as vol @@ -34,7 +35,7 @@ from homeassistant.loader import bind_hass from homeassistant.util.dt import parse_datetime, utcnow -# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any DOMAIN = "automation" @@ -281,11 +282,11 @@ class AutomationEntity(ToggleEntity, RestoreEntity): if enable_automation: await self.async_enable() - async def async_turn_on(self, **kwargs) -> None: + async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on and update the state.""" await self.async_enable() - async def async_turn_off(self, **kwargs) -> None: + async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" await self.async_disable() diff --git a/homeassistant/components/cover/__init__.py b/homeassistant/components/cover/__init__.py index d491765bb00..8d2b4430fe1 100644 --- a/homeassistant/components/cover/__init__.py +++ b/homeassistant/components/cover/__init__.py @@ -2,6 +2,7 @@ from datetime import timedelta import functools as ft import logging +from typing import Any import voluptuous as vol @@ -33,7 +34,7 @@ from homeassistant.const import ( ) -# mypy: allow-untyped-calls, allow-incomplete-defs, allow-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs _LOGGER = logging.getLogger(__name__) @@ -263,7 +264,7 @@ class CoverDevice(Entity): """Return if the cover is closed or not.""" raise NotImplementedError() - def open_cover(self, **kwargs): + def open_cover(self, **kwargs: Any) -> None: """Open the cover.""" raise NotImplementedError() @@ -274,7 +275,7 @@ class CoverDevice(Entity): """ return self.hass.async_add_job(ft.partial(self.open_cover, **kwargs)) - def close_cover(self, **kwargs): + def close_cover(self, **kwargs: Any) -> None: """Close cover.""" raise NotImplementedError() @@ -285,7 +286,7 @@ class CoverDevice(Entity): """ return self.hass.async_add_job(ft.partial(self.close_cover, **kwargs)) - def toggle(self, **kwargs) -> None: + def toggle(self, **kwargs: Any) -> None: """Toggle the entity.""" if self.is_closed: self.open_cover(**kwargs) @@ -323,7 +324,7 @@ class CoverDevice(Entity): """ return self.hass.async_add_job(ft.partial(self.stop_cover, **kwargs)) - def open_cover_tilt(self, **kwargs): + def open_cover_tilt(self, **kwargs: Any) -> None: """Open the cover tilt.""" pass @@ -334,7 +335,7 @@ class CoverDevice(Entity): """ return self.hass.async_add_job(ft.partial(self.open_cover_tilt, **kwargs)) - def close_cover_tilt(self, **kwargs): + def close_cover_tilt(self, **kwargs: Any) -> None: """Close the cover tilt.""" pass @@ -369,7 +370,7 @@ class CoverDevice(Entity): """ return self.hass.async_add_job(ft.partial(self.stop_cover_tilt, **kwargs)) - def toggle_tilt(self, **kwargs) -> None: + def toggle_tilt(self, **kwargs: Any) -> None: """Toggle the entity.""" if self.current_cover_tilt_position == 0: self.open_cover_tilt(**kwargs) diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 7298ce8c1d0..8ef662ec878 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -4,6 +4,7 @@ import logging import mimetypes import os import pathlib +from typing import Optional, Set, Tuple from aiohttp import web, web_urldispatcher, hdrs import voluptuous as vol @@ -22,7 +23,7 @@ from homeassistant.loader import bind_hass from .storage import async_setup_frontend_storage -# mypy: allow-incomplete-defs, allow-untyped-defs, no-check-untyped-defs +# mypy: allow-untyped-defs, no-check-untyped-defs # Fix mimetypes for borked Windows machines # https://github.com/home-assistant/home-assistant-polymer/issues/3336 @@ -400,7 +401,9 @@ class IndexView(web_urldispatcher.AbstractResource): """Construct url for resource with additional params.""" return URL("/") - async def resolve(self, request: web.Request): + async def resolve( + self, request: web.Request + ) -> Tuple[Optional[web_urldispatcher.UrlMappingMatchInfo], Set[str]]: """Resolve resource. Return (UrlMappingMatchInfo, allowed_methods) pair. @@ -447,7 +450,7 @@ class IndexView(web_urldispatcher.AbstractResource): return tpl - async def get(self, request: web.Request): + async def get(self, request: web.Request) -> web.Response: """Serve the index page for panel pages.""" hass = request.app["hass"] diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index d8fa8853c7f..7d1e24f3698 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -18,7 +18,7 @@ from homeassistant.util.yaml import dump from .const import KEY_REAL_IP -# mypy: allow-incomplete-defs, allow-untyped-defs, no-check-untyped-defs +# mypy: allow-untyped-defs, no-check-untyped-defs _LOGGER = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class IpBan: self.banned_at = banned_at or datetime.utcnow() -async def async_load_ip_bans_config(hass: HomeAssistant, path: str): +async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> List[IpBan]: """Load list of banned IPs from config file.""" ip_list: List[IpBan] = [] @@ -188,7 +188,7 @@ async def async_load_ip_bans_config(hass: HomeAssistant, path: str): return ip_list -def update_ip_bans_config(path: str, ip_ban: IpBan): +def update_ip_bans_config(path: str, ip_ban: IpBan) -> None: """Update config file with new banned IP address.""" with open(path, "a") as out: ip_ = { diff --git a/homeassistant/components/media_player/reproduce_state.py b/homeassistant/components/media_player/reproduce_state.py index 4eba4657d95..dac08afe471 100644 --- a/homeassistant/components/media_player/reproduce_state.py +++ b/homeassistant/components/media_player/reproduce_state.py @@ -36,7 +36,7 @@ from .const import ( ) -# mypy: allow-incomplete-defs, allow-untyped-defs +# mypy: allow-untyped-defs async def _async_reproduce_states( @@ -44,7 +44,7 @@ async def _async_reproduce_states( ) -> None: """Reproduce component states.""" - async def call_service(service: str, keys: Iterable): + async def call_service(service: str, keys: Iterable) -> None: """Call service with set of attributes given.""" data = {} data["entity_id"] = state.entity_id diff --git a/homeassistant/components/switch/light.py b/homeassistant/components/switch/light.py index 2027a8fc458..8f3b5d87f8c 100644 --- a/homeassistant/components/switch/light.py +++ b/homeassistant/components/switch/light.py @@ -1,6 +1,6 @@ """Light support for switch entities.""" import logging -from typing import cast +from typing import cast, Callable, Dict, Optional, Sequence import voluptuous as vol @@ -14,13 +14,14 @@ from homeassistant.const import ( ) from homeassistant.core import State, callback import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.components.light import PLATFORM_SCHEMA, Light -# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs _LOGGER = logging.getLogger(__name__) @@ -35,7 +36,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( async def async_setup_platform( - hass: HomeAssistantType, config: ConfigType, async_add_entities, discovery_info=None + hass: HomeAssistantType, + config: ConfigType, + async_add_entities: Callable[[Sequence[Entity], bool], None], + discovery_info: Optional[Dict] = None, ) -> None: """Initialize Light Switch platform.""" async_add_entities( @@ -105,7 +109,7 @@ class LightSwitch(Light): @callback def async_state_changed_listener( entity_id: str, old_state: State, new_state: State - ): + ) -> None: """Handle child updates.""" self.async_schedule_update_ha_state(True) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index e53954a65dd..952fa41c42c 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -52,7 +52,7 @@ from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.util import slugify as util_slugify -# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any # pylint: disable=invalid-name @@ -95,7 +95,7 @@ def has_at_least_one_key(*keys: str) -> Callable: return validate -def has_at_most_one_key(*keys: str) -> Callable: +def has_at_most_one_key(*keys: str) -> Callable[[Dict], Dict]: """Validate that zero keys exist or one key exists.""" def validate(obj: Dict) -> Dict: @@ -224,7 +224,7 @@ def entity_ids(value: Union[str, List]) -> List[str]: comp_entity_ids = vol.Any(vol.All(vol.Lower, ENTITY_MATCH_ALL), entity_ids) -def entity_domain(domain: str): +def entity_domain(domain: str) -> Callable[[Any], str]: """Validate that entity belong to domain.""" def validate(value: Any) -> str: @@ -235,7 +235,7 @@ def entity_domain(domain: str): return validate -def entities_domain(domain: str): +def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]: """Validate that entities belong to domain.""" def validate(values: Union[str, List]) -> List[str]: @@ -284,7 +284,7 @@ time_period_dict = vol.All( ) -def time(value) -> time_sys: +def time(value: Any) -> time_sys: """Validate and transform a time.""" if isinstance(value, time_sys): return value @@ -300,7 +300,7 @@ def time(value) -> time_sys: return time_val -def date(value) -> date_sys: +def date(value: Any) -> date_sys: """Validate and transform a date.""" if isinstance(value, date_sys): return value @@ -439,7 +439,7 @@ def string(value: Any) -> str: return str(value) -def temperature_unit(value) -> str: +def temperature_unit(value: Any) -> str: """Validate and transform temperature unit.""" value = str(value).upper() if value == "C": @@ -578,7 +578,7 @@ def deprecated( replacement_key: Optional[str] = None, invalidation_version: Optional[str] = None, default: Optional[Any] = None, -): +) -> Callable[[Dict], Dict]: """ Log key as deprecated and provide a replacement (if exists). @@ -626,7 +626,7 @@ def deprecated( " deprecated, please remove it from your configuration" ) - def check_for_invalid_version(value: Optional[Any]): + def check_for_invalid_version(value: Optional[Any]) -> None: """Raise error if current version has reached invalidation.""" if not invalidation_version: return @@ -641,7 +641,7 @@ def deprecated( ) ) - def validator(config: Dict): + def validator(config: Dict) -> Dict: """Check if key is in config and log warning.""" if key in config: value = config[key] diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 0b569e2d4ad..23728b65109 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -4,7 +4,7 @@ import logging from contextlib import suppress from datetime import datetime from itertools import islice -from typing import Optional, Sequence, Callable, Dict, List, Set, Tuple +from typing import Optional, Sequence, Callable, Dict, List, Set, Tuple, Any import voluptuous as vol @@ -32,8 +32,7 @@ import homeassistant.util.dt as date_util from homeassistant.util.async_ import run_coroutine_threadsafe, run_callback_threadsafe -# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs _LOGGER = logging.getLogger(__name__) @@ -101,9 +100,9 @@ class Script: def __init__( self, hass: HomeAssistant, - sequence, + sequence: Sequence[Dict[str, Any]], name: Optional[str] = None, - change_listener=None, + change_listener: Optional[Callable[..., Any]] = None, ) -> None: """Initialize the script.""" self.hass = hass diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 98e3849bfb6..9af1998e894 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -7,7 +7,7 @@ import random import re from datetime import datetime from functools import wraps -from typing import Iterable +from typing import Any, Iterable import jinja2 from jinja2 import contextfilter, contextfunction @@ -25,13 +25,13 @@ from homeassistant.const import ( from homeassistant.core import State, callback, split_entity_id, valid_entity_id from homeassistant.exceptions import TemplateError from homeassistant.helpers import location as loc_helper -from homeassistant.helpers.typing import TemplateVarsType +from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import convert, dt as dt_util, location as loc_util from homeassistant.util.async_ import run_callback_threadsafe -# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs +# mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any _LOGGER = logging.getLogger(__name__) @@ -106,7 +106,7 @@ def extract_entities(template, variables=None): return MATCH_ALL -def _true(arg) -> bool: +def _true(arg: Any) -> bool: return True @@ -191,7 +191,7 @@ class Template: """Extract all entities for state_changed listener.""" return extract_entities(self.template, variables) - def render(self, variables: TemplateVarsType = None, **kwargs): + def render(self, variables: TemplateVarsType = None, **kwargs: Any) -> str: """Render given template.""" if variables is not None: kwargs.update(variables) @@ -201,7 +201,7 @@ class Template: ).result() @callback - def async_render(self, variables: TemplateVarsType = None, **kwargs) -> str: + def async_render(self, variables: TemplateVarsType = None, **kwargs: Any) -> str: """Render given template. This method must be run in the event loop. @@ -218,7 +218,7 @@ class Template: @callback def async_render_to_info( - self, variables: TemplateVarsType = None, **kwargs + self, variables: TemplateVarsType = None, **kwargs: Any ) -> RenderInfo: """Render the template and collect an entity filter.""" assert self.hass and _RENDER_INFO not in self.hass.data @@ -479,7 +479,7 @@ def _resolve_state(hass, entity_id_or_state): return None -def expand(hass, *args) -> Iterable[State]: +def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]: """Expand out any groups into entity states.""" search = list(args) found = {} @@ -635,7 +635,7 @@ def distance(hass, *args): ) -def is_state(hass, entity_id: str, state: State) -> bool: +def is_state(hass: HomeAssistantType, entity_id: str, state: State) -> bool: """Test if a state is a specific value.""" state_obj = _get_state(hass, entity_id) return state_obj is not None and state_obj.state == state diff --git a/homeassistant/scripts/__init__.py b/homeassistant/scripts/__init__.py index 0a9bac30188..00f5984c58b 100644 --- a/homeassistant/scripts/__init__.py +++ b/homeassistant/scripts/__init__.py @@ -5,7 +5,7 @@ import importlib import logging import os import sys -from typing import List +from typing import List, Optional, Sequence, Text from homeassistant.bootstrap import async_mount_local_lib_path from homeassistant.config import get_default_config_dir @@ -13,7 +13,7 @@ from homeassistant.requirements import pip_kwargs from homeassistant.util.package import install_package, is_virtual_env, is_installed -# mypy: allow-untyped-defs, allow-incomplete-defs, no-warn-return-any +# mypy: allow-untyped-defs, no-warn-return-any def run(args: List) -> int: @@ -62,13 +62,13 @@ def run(args: List) -> int: return script.run(args[1:]) # type: ignore -def extract_config_dir(args=None) -> str: +def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str: """Extract the config dir from the arguments or get the default.""" parser = argparse.ArgumentParser(add_help=False) parser.add_argument("-c", "--config", default=None) - args = parser.parse_known_args(args)[0] + parsed_args = parser.parse_known_args(args)[0] return ( - os.path.join(os.getcwd(), args.config) - if args.config + os.path.join(os.getcwd(), parsed_args.config) + if parsed_args.config else get_default_config_dir() )