From 868eb3c735c2136cf81a5f18a19cd08a7d0522a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Sun, 22 Dec 2019 20:51:39 +0200 Subject: [PATCH] More helpers type improvements (#30145) --- homeassistant/helpers/check_config.py | 26 +++-- homeassistant/helpers/config_validation.py | 115 ++++++++++++--------- homeassistant/helpers/device_registry.py | 8 +- homeassistant/helpers/entity_registry.py | 42 ++++---- homeassistant/helpers/logging.py | 23 +++-- 5 files changed, 124 insertions(+), 90 deletions(-) diff --git a/homeassistant/helpers/check_config.py b/homeassistant/helpers/check_config.py index 81e654247b7..1b1e136ed89 100644 --- a/homeassistant/helpers/check_config.py +++ b/homeassistant/helpers/check_config.py @@ -1,6 +1,6 @@ """Helper to check the configuration file.""" -from collections import OrderedDict, namedtuple -from typing import List +from collections import OrderedDict +from typing import List, NamedTuple, Optional import attr import voluptuous as vol @@ -19,15 +19,20 @@ from homeassistant.config import ( ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.typing import ConfigType from homeassistant.requirements import ( RequirementsNotFound, async_get_integration_with_requirements, ) import homeassistant.util.yaml.loader as yaml_loader -# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any -CheckConfigError = namedtuple("CheckConfigError", "message domain config") +class CheckConfigError(NamedTuple): + """Configuration check error.""" + + message: str + domain: Optional[str] + config: Optional[ConfigType] @attr.s @@ -36,7 +41,12 @@ class HomeAssistantConfig(OrderedDict): errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list)) - def add_error(self, message, domain=None, config=None): + def add_error( + self, + message: str, + domain: Optional[str] = None, + config: Optional[ConfigType] = None, + ) -> "HomeAssistantConfig": """Add a single error.""" self.errors.append(CheckConfigError(str(message), domain, config)) return self @@ -55,7 +65,9 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig config_dir = hass.config.config_dir result = HomeAssistantConfig() - def _pack_error(package, component, config, message): + def _pack_error( + package: str, component: str, config: ConfigType, message: str + ) -> None: """Handle errors from packages: _log_pkg_error.""" message = "Package {} setup failed. Component {} {}".format( package, component, message @@ -64,7 +76,7 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig pack_config = core_config[CONF_PACKAGES].get(package, config) result.add_error(message, domain, pack_config) - def _comp_error(ex, domain, config): + def _comp_error(ex: Exception, domain: str, config: ConfigType) -> None: """Handle errors from components: async_log_exception.""" result.add_error(_format_config_error(ex, domain, config), domain, config) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 5787db65102..035e1f678bf 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -5,13 +5,26 @@ from datetime import ( time as time_sys, timedelta, ) +from enum import Enum import inspect import logging from numbers import Number import os import re from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + List, + Optional, + Pattern, + Type, + TypeVar, + Union, + cast, +) from urllib.parse import urlparse from uuid import UUID @@ -48,12 +61,11 @@ from homeassistant.const import ( ) from homeassistant.core import split_entity_id, valid_entity_id from homeassistant.exceptions import TemplateError +from homeassistant.helpers import template as template_helper from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.util import slugify as util_slugify import homeassistant.util.dt as dt_util -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any # pylint: disable=invalid-name TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'" @@ -126,7 +138,7 @@ def boolean(value: Any) -> bool: raise vol.Invalid("invalid boolean value {}".format(value)) -def isdevice(value): +def isdevice(value: Any) -> str: """Validate that value is a real device.""" try: os.stat(value) @@ -135,19 +147,19 @@ def isdevice(value): raise vol.Invalid("No device at {} found".format(value)) -def matches_regex(regex): +def matches_regex(regex: str) -> Callable[[Any], str]: """Validate that the value is a string that matches a regex.""" - regex = re.compile(regex) + compiled = re.compile(regex) def validator(value: Any) -> str: """Validate that value matches the given regex.""" if not isinstance(value, str): raise vol.Invalid("not a string value: {}".format(value)) - if not regex.match(value): + if not compiled.match(value): raise vol.Invalid( "value {} does not match regular expression {}".format( - value, regex.pattern + value, compiled.pattern ) ) @@ -156,14 +168,14 @@ def matches_regex(regex): return validator -def is_regex(value): +def is_regex(value: Any) -> Pattern[Any]: """Validate that a string is a valid regular expression.""" try: r = re.compile(value) return r except TypeError: raise vol.Invalid( - "value {} is of the wrong type for a regular " "expression".format(value) + "value {} is of the wrong type for a regular expression".format(value) ) except re.error: raise vol.Invalid("value {} is not a valid regular expression".format(value)) @@ -204,9 +216,9 @@ def ensure_list(value: Union[T, List[T], None]) -> List[T]: def entity_id(value: Any) -> str: """Validate Entity ID.""" - value = string(value).lower() - if valid_entity_id(value): - return value + str_value = string(value).lower() + if valid_entity_id(str_value): + return str_value raise vol.Invalid("Entity ID {} is an invalid entity id".format(value)) @@ -253,17 +265,17 @@ def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]: return validate -def enum(enumClass): +def enum(enumClass: Type[Enum]) -> vol.All: """Create validator for specified enum.""" return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__) -def icon(value): +def icon(value: Any) -> str: """Validate icon.""" - value = str(value) + str_value = str(value) - if ":" in value: - return value + if ":" in str_value: + return str_value raise vol.Invalid('Icons should be specified in the form "prefix:name"') @@ -362,7 +374,7 @@ def time_period_seconds(value: Union[int, str]) -> timedelta: time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict) -def match_all(value): +def match_all(value: T) -> T: """Validate that matches all values.""" return value @@ -382,12 +394,12 @@ def remove_falsy(value: List[T]) -> List[T]: return [v for v in value if v] -def service(value): +def service(value: Any) -> str: """Validate service.""" # Services use same format as entities so we can use same helper. - value = string(value).lower() - if valid_entity_id(value): - return value + str_value = string(value).lower() + if valid_entity_id(str_value): + return str_value raise vol.Invalid("Service {} does not match format .".format(value)) @@ -407,7 +419,7 @@ def schema_with_slug_keys(value_schema: Union[T, Callable]) -> Callable: for key in value.keys(): slug(key) - return schema(value) + return cast(Dict, schema(value)) return verify @@ -416,10 +428,10 @@ def slug(value: Any) -> str: """Validate value is a valid slug.""" if value is None: raise vol.Invalid("Slug should not be None") - value = str(value) - slg = util_slugify(value) - if value == slg: - return value + str_value = str(value) + slg = util_slugify(str_value) + if str_value == slg: + return str_value raise vol.Invalid("invalid slug {} (try {})".format(value, slg)) @@ -458,42 +470,41 @@ unit_system = vol.All( ) -def template(value): +def template(value: Optional[Any]) -> template_helper.Template: """Validate a jinja2 template.""" - from homeassistant.helpers import template as template_helper if value is None: raise vol.Invalid("template value is None") if isinstance(value, (list, dict, template_helper.Template)): raise vol.Invalid("template value should be a string") - value = template_helper.Template(str(value)) + template_value = template_helper.Template(str(value)) # type: ignore try: - value.ensure_valid() - return value + template_value.ensure_valid() + return cast(template_helper.Template, template_value) except TemplateError as ex: raise vol.Invalid("invalid template ({})".format(ex)) -def template_complex(value): +def template_complex(value: Any) -> Any: """Validate a complex jinja2 template.""" if isinstance(value, list): - return_value = value.copy() - for idx, element in enumerate(return_value): - return_value[idx] = template_complex(element) - return return_value + return_list = value.copy() + for idx, element in enumerate(return_list): + return_list[idx] = template_complex(element) + return return_list if isinstance(value, dict): - return_value = value.copy() - for key, element in return_value.items(): - return_value[key] = template_complex(element) - return return_value + return_dict = value.copy() + for key, element in return_dict.items(): + return_dict[key] = template_complex(element) + return return_dict if isinstance(value, str): return template(value) return value -def datetime(value): +def datetime(value: Any) -> datetime_sys: """Validate datetime.""" if isinstance(value, datetime_sys): return value @@ -509,7 +520,7 @@ def datetime(value): return date_val -def time_zone(value): +def time_zone(value: str) -> str: """Validate timezone.""" if dt_util.get_time_zone(value) is not None: return value @@ -522,7 +533,7 @@ def time_zone(value): weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)]) -def socket_timeout(value): +def socket_timeout(value: Optional[Any]) -> object: """Validate timeout float > 0.0. None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object. @@ -544,12 +555,12 @@ def url(value: Any) -> str: url_in = str(value) if urlparse(url_in).scheme in ["http", "https"]: - return vol.Schema(vol.Url())(url_in) + return cast(str, vol.Schema(vol.Url())(url_in)) raise vol.Invalid("invalid url") -def x10_address(value): +def x10_address(value: str) -> str: """Validate an x10 address.""" regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$") if not regex.match(value): @@ -557,7 +568,7 @@ def x10_address(value): return str(value).lower() -def uuid4_hex(value): +def uuid4_hex(value: Any) -> str: """Validate a v4 UUID in hex format.""" try: result = UUID(value, version=4) @@ -678,10 +689,12 @@ def deprecated( # Validator helpers -def key_dependency(key, dependency): +def key_dependency( + key: Hashable, dependency: Hashable +) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]: """Validate that all dependencies exist for key.""" - def validator(value): + def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]: """Test dependencies.""" if not isinstance(value, dict): raise vol.Invalid("key dependencies require a dict") @@ -696,7 +709,7 @@ def key_dependency(key, dependency): return validator -def custom_serializer(schema): +def custom_serializer(schema: Any) -> Any: """Serialize additional types for voluptuous_serialize.""" if schema is positive_time_period_dict: return {"type": "positive_time_period_dict"} diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 4818de83cb9..512334c8d3c 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -12,8 +12,7 @@ from homeassistant.loader import bind_hass from .typing import HomeAssistantType -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any +# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs _LOGGER = logging.getLogger(__name__) _UNDEF = object() @@ -71,10 +70,11 @@ def format_mac(mac: str) -> str: class DeviceRegistry: """Class to hold a registry of devices.""" - def __init__(self, hass): + devices: Dict[str, DeviceEntry] + + def __init__(self, hass: HomeAssistantType) -> None: """Initialize the device registry.""" self.hass = hass - self.devices = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index a5bd62d973c..5eb79965880 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -11,7 +11,7 @@ import asyncio from collections import OrderedDict from itertools import chain import logging -from typing import Any, Dict, Iterable, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast import attr @@ -23,6 +23,9 @@ from homeassistant.util.yaml import load_yaml from .typing import HomeAssistantType +if TYPE_CHECKING: + from homeassistant.config_entries import ConfigEntry # noqa: F401 + # mypy: allow-untyped-defs, no-check-untyped-defs PATH_REGISTRY = "entity_registry.yaml" @@ -48,7 +51,7 @@ class RegistryEntry: unique_id = attr.ib(type=str) platform = attr.ib(type=str) name = attr.ib(type=str, default=None) - device_id = attr.ib(type=str, default=None) + device_id: Optional[str] = attr.ib(default=None) config_entry_id: Optional[str] = attr.ib(default=None) disabled_by = attr.ib( type=Optional[str], @@ -135,16 +138,16 @@ class EntityRegistry: @callback def async_get_or_create( self, - domain, - platform, - unique_id, + domain: str, + platform: str, + unique_id: str, *, - suggested_object_id=None, - config_entry=None, - device_id=None, - known_object_ids=None, - disabled_by=None, - ): + suggested_object_id: Optional[str] = None, + config_entry: Optional["ConfigEntry"] = None, + device_id: Optional[str] = None, + known_object_ids: Optional[Iterable[str]] = None, + disabled_by: Optional[str] = None, + ) -> RegistryEntry: """Get entity. Create if it doesn't exist.""" config_entry_id = None if config_entry: @@ -153,7 +156,7 @@ class EntityRegistry: entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: - return self._async_update_entity( + return self._async_update_entity( # type: ignore entity_id, config_entry_id=config_entry_id or _UNDEF, device_id=device_id or _UNDEF, @@ -228,12 +231,15 @@ class EntityRegistry: disabled_by=_UNDEF, ): """Update properties of an entity.""" - return self._async_update_entity( - entity_id, - name=name, - new_entity_id=new_entity_id, - new_unique_id=new_unique_id, - disabled_by=disabled_by, + return cast( # cast until we have _async_update_entity type hinted + RegistryEntry, + self._async_update_entity( + entity_id, + name=name, + new_entity_id=new_entity_id, + new_unique_id=new_unique_id, + disabled_by=disabled_by, + ), ) @callback diff --git a/homeassistant/helpers/logging.py b/homeassistant/helpers/logging.py index 7b2507d9e05..0b274458045 100644 --- a/homeassistant/helpers/logging.py +++ b/homeassistant/helpers/logging.py @@ -1,8 +1,7 @@ """Helpers for logging allowing more advanced logging styles to be used.""" import inspect import logging - -# mypy: allow-untyped-defs, no-check-untyped-defs +from typing import Any, Mapping, MutableMapping, Optional, Tuple class KeywordMessage: @@ -12,13 +11,13 @@ class KeywordMessage: Adapted from: https://stackoverflow.com/a/24683360/2267718 """ - def __init__(self, fmt, args, kwargs): - """Initialize a new BraceMessage object.""" + def __init__(self, fmt: Any, args: Any, kwargs: Mapping[str, Any]) -> None: + """Initialize a new KeywordMessage object.""" self._fmt = fmt self._args = args self._kwargs = kwargs - def __str__(self): + def __str__(self) -> str: """Convert the object to a string for logging.""" return str(self._fmt).format(*self._args, **self._kwargs) @@ -26,26 +25,30 @@ class KeywordMessage: class KeywordStyleAdapter(logging.LoggerAdapter): """Represents an adapter wrapping the logger allowing KeywordMessages.""" - def __init__(self, logger, extra=None): + def __init__( + self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None + ) -> None: """Initialize a new StyleAdapter for the provided logger.""" super().__init__(logger, extra or {}) - def log(self, level, msg, *args, **kwargs): + def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None: """Log the message provided at the appropriate level.""" if self.isEnabledFor(level): msg, log_kwargs = self.process(msg, kwargs) - self.logger._log( # pylint: disable=protected-access + self.logger._log( # type: ignore # pylint: disable=protected-access level, KeywordMessage(msg, args, kwargs), (), **log_kwargs ) - def process(self, msg, kwargs): + def process( + self, msg: Any, kwargs: MutableMapping[str, Any] + ) -> Tuple[Any, MutableMapping[str, Any]]: """Process the keyward args in preparation for logging.""" return ( msg, { k: kwargs[k] for k in inspect.getfullargspec( - self.logger._log # pylint: disable=protected-access + self.logger._log # type: ignore # pylint: disable=protected-access ).args[1:] if k in kwargs },