More helpers type improvements (#30145)

This commit is contained in:
Ville Skyttä 2019-12-22 20:51:39 +02:00 committed by Paulus Schoutsen
parent 70f8bfbd4f
commit 868eb3c735
5 changed files with 124 additions and 90 deletions

View file

@ -1,6 +1,6 @@
"""Helper to check the configuration file."""
from collections import OrderedDict, namedtuple
from typing import List
from collections import OrderedDict
from typing import List, NamedTuple, Optional
import attr
import voluptuous as vol
@ -19,15 +19,20 @@ from homeassistant.config import (
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from homeassistant.requirements import (
RequirementsNotFound,
async_get_integration_with_requirements,
)
import homeassistant.util.yaml.loader as yaml_loader
# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any
CheckConfigError = namedtuple("CheckConfigError", "message domain config")
class CheckConfigError(NamedTuple):
"""Configuration check error."""
message: str
domain: Optional[str]
config: Optional[ConfigType]
@attr.s
@ -36,7 +41,12 @@ class HomeAssistantConfig(OrderedDict):
errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list))
def add_error(self, message, domain=None, config=None):
def add_error(
self,
message: str,
domain: Optional[str] = None,
config: Optional[ConfigType] = None,
) -> "HomeAssistantConfig":
"""Add a single error."""
self.errors.append(CheckConfigError(str(message), domain, config))
return self
@ -55,7 +65,9 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
config_dir = hass.config.config_dir
result = HomeAssistantConfig()
def _pack_error(package, component, config, message):
def _pack_error(
package: str, component: str, config: ConfigType, message: str
) -> None:
"""Handle errors from packages: _log_pkg_error."""
message = "Package {} setup failed. Component {} {}".format(
package, component, message
@ -64,7 +76,7 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
pack_config = core_config[CONF_PACKAGES].get(package, config)
result.add_error(message, domain, pack_config)
def _comp_error(ex, domain, config):
def _comp_error(ex: Exception, domain: str, config: ConfigType) -> None:
"""Handle errors from components: async_log_exception."""
result.add_error(_format_config_error(ex, domain, config), domain, config)

View file

@ -5,13 +5,26 @@ from datetime import (
time as time_sys,
timedelta,
)
from enum import Enum
import inspect
import logging
from numbers import Number
import os
import re
from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Hashable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse
from uuid import UUID
@ -48,12 +61,11 @@ from homeassistant.const import (
)
from homeassistant.core import split_entity_id, valid_entity_id
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template as template_helper
from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import slugify as util_slugify
import homeassistant.util.dt as dt_util
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
# pylint: disable=invalid-name
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
@ -126,7 +138,7 @@ def boolean(value: Any) -> bool:
raise vol.Invalid("invalid boolean value {}".format(value))
def isdevice(value):
def isdevice(value: Any) -> str:
"""Validate that value is a real device."""
try:
os.stat(value)
@ -135,19 +147,19 @@ def isdevice(value):
raise vol.Invalid("No device at {} found".format(value))
def matches_regex(regex):
def matches_regex(regex: str) -> Callable[[Any], str]:
"""Validate that the value is a string that matches a regex."""
regex = re.compile(regex)
compiled = re.compile(regex)
def validator(value: Any) -> str:
"""Validate that value matches the given regex."""
if not isinstance(value, str):
raise vol.Invalid("not a string value: {}".format(value))
if not regex.match(value):
if not compiled.match(value):
raise vol.Invalid(
"value {} does not match regular expression {}".format(
value, regex.pattern
value, compiled.pattern
)
)
@ -156,14 +168,14 @@ def matches_regex(regex):
return validator
def is_regex(value):
def is_regex(value: Any) -> Pattern[Any]:
"""Validate that a string is a valid regular expression."""
try:
r = re.compile(value)
return r
except TypeError:
raise vol.Invalid(
"value {} is of the wrong type for a regular " "expression".format(value)
"value {} is of the wrong type for a regular expression".format(value)
)
except re.error:
raise vol.Invalid("value {} is not a valid regular expression".format(value))
@ -204,9 +216,9 @@ def ensure_list(value: Union[T, List[T], None]) -> List[T]:
def entity_id(value: Any) -> str:
"""Validate Entity ID."""
value = string(value).lower()
if valid_entity_id(value):
return value
str_value = string(value).lower()
if valid_entity_id(str_value):
return str_value
raise vol.Invalid("Entity ID {} is an invalid entity id".format(value))
@ -253,17 +265,17 @@ def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]:
return validate
def enum(enumClass):
def enum(enumClass: Type[Enum]) -> vol.All:
"""Create validator for specified enum."""
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
def icon(value):
def icon(value: Any) -> str:
"""Validate icon."""
value = str(value)
str_value = str(value)
if ":" in value:
return value
if ":" in str_value:
return str_value
raise vol.Invalid('Icons should be specified in the form "prefix:name"')
@ -362,7 +374,7 @@ def time_period_seconds(value: Union[int, str]) -> timedelta:
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
def match_all(value):
def match_all(value: T) -> T:
"""Validate that matches all values."""
return value
@ -382,12 +394,12 @@ def remove_falsy(value: List[T]) -> List[T]:
return [v for v in value if v]
def service(value):
def service(value: Any) -> str:
"""Validate service."""
# Services use same format as entities so we can use same helper.
value = string(value).lower()
if valid_entity_id(value):
return value
str_value = string(value).lower()
if valid_entity_id(str_value):
return str_value
raise vol.Invalid("Service {} does not match format <domain>.<name>".format(value))
@ -407,7 +419,7 @@ def schema_with_slug_keys(value_schema: Union[T, Callable]) -> Callable:
for key in value.keys():
slug(key)
return schema(value)
return cast(Dict, schema(value))
return verify
@ -416,10 +428,10 @@ def slug(value: Any) -> str:
"""Validate value is a valid slug."""
if value is None:
raise vol.Invalid("Slug should not be None")
value = str(value)
slg = util_slugify(value)
if value == slg:
return value
str_value = str(value)
slg = util_slugify(str_value)
if str_value == slg:
return str_value
raise vol.Invalid("invalid slug {} (try {})".format(value, slg))
@ -458,42 +470,41 @@ unit_system = vol.All(
)
def template(value):
def template(value: Optional[Any]) -> template_helper.Template:
"""Validate a jinja2 template."""
from homeassistant.helpers import template as template_helper
if value is None:
raise vol.Invalid("template value is None")
if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string")
value = template_helper.Template(str(value))
template_value = template_helper.Template(str(value)) # type: ignore
try:
value.ensure_valid()
return value
template_value.ensure_valid()
return cast(template_helper.Template, template_value)
except TemplateError as ex:
raise vol.Invalid("invalid template ({})".format(ex))
def template_complex(value):
def template_complex(value: Any) -> Any:
"""Validate a complex jinja2 template."""
if isinstance(value, list):
return_value = value.copy()
for idx, element in enumerate(return_value):
return_value[idx] = template_complex(element)
return return_value
return_list = value.copy()
for idx, element in enumerate(return_list):
return_list[idx] = template_complex(element)
return return_list
if isinstance(value, dict):
return_value = value.copy()
for key, element in return_value.items():
return_value[key] = template_complex(element)
return return_value
return_dict = value.copy()
for key, element in return_dict.items():
return_dict[key] = template_complex(element)
return return_dict
if isinstance(value, str):
return template(value)
return value
def datetime(value):
def datetime(value: Any) -> datetime_sys:
"""Validate datetime."""
if isinstance(value, datetime_sys):
return value
@ -509,7 +520,7 @@ def datetime(value):
return date_val
def time_zone(value):
def time_zone(value: str) -> str:
"""Validate timezone."""
if dt_util.get_time_zone(value) is not None:
return value
@ -522,7 +533,7 @@ def time_zone(value):
weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
def socket_timeout(value):
def socket_timeout(value: Optional[Any]) -> object:
"""Validate timeout float > 0.0.
None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
@ -544,12 +555,12 @@ def url(value: Any) -> str:
url_in = str(value)
if urlparse(url_in).scheme in ["http", "https"]:
return vol.Schema(vol.Url())(url_in)
return cast(str, vol.Schema(vol.Url())(url_in))
raise vol.Invalid("invalid url")
def x10_address(value):
def x10_address(value: str) -> str:
"""Validate an x10 address."""
regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$")
if not regex.match(value):
@ -557,7 +568,7 @@ def x10_address(value):
return str(value).lower()
def uuid4_hex(value):
def uuid4_hex(value: Any) -> str:
"""Validate a v4 UUID in hex format."""
try:
result = UUID(value, version=4)
@ -678,10 +689,12 @@ def deprecated(
# Validator helpers
def key_dependency(key, dependency):
def key_dependency(
key: Hashable, dependency: Hashable
) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]:
"""Validate that all dependencies exist for key."""
def validator(value):
def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
"""Test dependencies."""
if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict")
@ -696,7 +709,7 @@ def key_dependency(key, dependency):
return validator
def custom_serializer(schema):
def custom_serializer(schema: Any) -> Any:
"""Serialize additional types for voluptuous_serialize."""
if schema is positive_time_period_dict:
return {"type": "positive_time_period_dict"}

View file

@ -12,8 +12,7 @@ from homeassistant.loader import bind_hass
from .typing import HomeAssistantType
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
@ -71,10 +70,11 @@ def format_mac(mac: str) -> str:
class DeviceRegistry:
"""Class to hold a registry of devices."""
def __init__(self, hass):
devices: Dict[str, DeviceEntry]
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry."""
self.hass = hass
self.devices = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback

View file

@ -11,7 +11,7 @@ import asyncio
from collections import OrderedDict
from itertools import chain
import logging
from typing import Any, Dict, Iterable, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
import attr
@ -23,6 +23,9 @@ from homeassistant.util.yaml import load_yaml
from .typing import HomeAssistantType
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry # noqa: F401
# mypy: allow-untyped-defs, no-check-untyped-defs
PATH_REGISTRY = "entity_registry.yaml"
@ -48,7 +51,7 @@ class RegistryEntry:
unique_id = attr.ib(type=str)
platform = attr.ib(type=str)
name = attr.ib(type=str, default=None)
device_id = attr.ib(type=str, default=None)
device_id: Optional[str] = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by = attr.ib(
type=Optional[str],
@ -135,16 +138,16 @@ class EntityRegistry:
@callback
def async_get_or_create(
self,
domain,
platform,
unique_id,
domain: str,
platform: str,
unique_id: str,
*,
suggested_object_id=None,
config_entry=None,
device_id=None,
known_object_ids=None,
disabled_by=None,
):
suggested_object_id: Optional[str] = None,
config_entry: Optional["ConfigEntry"] = None,
device_id: Optional[str] = None,
known_object_ids: Optional[Iterable[str]] = None,
disabled_by: Optional[str] = None,
) -> RegistryEntry:
"""Get entity. Create if it doesn't exist."""
config_entry_id = None
if config_entry:
@ -153,7 +156,7 @@ class EntityRegistry:
entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id:
return self._async_update_entity(
return self._async_update_entity( # type: ignore
entity_id,
config_entry_id=config_entry_id or _UNDEF,
device_id=device_id or _UNDEF,
@ -228,12 +231,15 @@ class EntityRegistry:
disabled_by=_UNDEF,
):
"""Update properties of an entity."""
return self._async_update_entity(
entity_id,
name=name,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
return cast( # cast until we have _async_update_entity type hinted
RegistryEntry,
self._async_update_entity(
entity_id,
name=name,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
),
)
@callback

View file

@ -1,8 +1,7 @@
"""Helpers for logging allowing more advanced logging styles to be used."""
import inspect
import logging
# mypy: allow-untyped-defs, no-check-untyped-defs
from typing import Any, Mapping, MutableMapping, Optional, Tuple
class KeywordMessage:
@ -12,13 +11,13 @@ class KeywordMessage:
Adapted from: https://stackoverflow.com/a/24683360/2267718
"""
def __init__(self, fmt, args, kwargs):
"""Initialize a new BraceMessage object."""
def __init__(self, fmt: Any, args: Any, kwargs: Mapping[str, Any]) -> None:
"""Initialize a new KeywordMessage object."""
self._fmt = fmt
self._args = args
self._kwargs = kwargs
def __str__(self):
def __str__(self) -> str:
"""Convert the object to a string for logging."""
return str(self._fmt).format(*self._args, **self._kwargs)
@ -26,26 +25,30 @@ class KeywordMessage:
class KeywordStyleAdapter(logging.LoggerAdapter):
"""Represents an adapter wrapping the logger allowing KeywordMessages."""
def __init__(self, logger, extra=None):
def __init__(
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
) -> None:
"""Initialize a new StyleAdapter for the provided logger."""
super().__init__(logger, extra or {})
def log(self, level, msg, *args, **kwargs):
def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Log the message provided at the appropriate level."""
if self.isEnabledFor(level):
msg, log_kwargs = self.process(msg, kwargs)
self.logger._log( # pylint: disable=protected-access
self.logger._log( # type: ignore # pylint: disable=protected-access
level, KeywordMessage(msg, args, kwargs), (), **log_kwargs
)
def process(self, msg, kwargs):
def process(
self, msg: Any, kwargs: MutableMapping[str, Any]
) -> Tuple[Any, MutableMapping[str, Any]]:
"""Process the keyward args in preparation for logging."""
return (
msg,
{
k: kwargs[k]
for k in inspect.getfullargspec(
self.logger._log # pylint: disable=protected-access
self.logger._log # type: ignore # pylint: disable=protected-access
).args[1:]
if k in kwargs
},