More helpers type improvements (#30145)
This commit is contained in:
parent
70f8bfbd4f
commit
868eb3c735
5 changed files with 124 additions and 90 deletions
|
@ -1,6 +1,6 @@
|
|||
"""Helper to check the configuration file."""
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import List
|
||||
from collections import OrderedDict
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
|
@ -19,15 +19,20 @@ from homeassistant.config import (
|
|||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.requirements import (
|
||||
RequirementsNotFound,
|
||||
async_get_integration_with_requirements,
|
||||
)
|
||||
import homeassistant.util.yaml.loader as yaml_loader
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any
|
||||
|
||||
CheckConfigError = namedtuple("CheckConfigError", "message domain config")
|
||||
class CheckConfigError(NamedTuple):
|
||||
"""Configuration check error."""
|
||||
|
||||
message: str
|
||||
domain: Optional[str]
|
||||
config: Optional[ConfigType]
|
||||
|
||||
|
||||
@attr.s
|
||||
|
@ -36,7 +41,12 @@ class HomeAssistantConfig(OrderedDict):
|
|||
|
||||
errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list))
|
||||
|
||||
def add_error(self, message, domain=None, config=None):
|
||||
def add_error(
|
||||
self,
|
||||
message: str,
|
||||
domain: Optional[str] = None,
|
||||
config: Optional[ConfigType] = None,
|
||||
) -> "HomeAssistantConfig":
|
||||
"""Add a single error."""
|
||||
self.errors.append(CheckConfigError(str(message), domain, config))
|
||||
return self
|
||||
|
@ -55,7 +65,9 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
|
|||
config_dir = hass.config.config_dir
|
||||
result = HomeAssistantConfig()
|
||||
|
||||
def _pack_error(package, component, config, message):
|
||||
def _pack_error(
|
||||
package: str, component: str, config: ConfigType, message: str
|
||||
) -> None:
|
||||
"""Handle errors from packages: _log_pkg_error."""
|
||||
message = "Package {} setup failed. Component {} {}".format(
|
||||
package, component, message
|
||||
|
@ -64,7 +76,7 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
|
|||
pack_config = core_config[CONF_PACKAGES].get(package, config)
|
||||
result.add_error(message, domain, pack_config)
|
||||
|
||||
def _comp_error(ex, domain, config):
|
||||
def _comp_error(ex: Exception, domain: str, config: ConfigType) -> None:
|
||||
"""Handle errors from components: async_log_exception."""
|
||||
result.add_error(_format_config_error(ex, domain, config), domain, config)
|
||||
|
||||
|
|
|
@ -5,13 +5,26 @@ from datetime import (
|
|||
time as time_sys,
|
||||
timedelta,
|
||||
)
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import logging
|
||||
from numbers import Number
|
||||
import os
|
||||
import re
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
|
@ -48,12 +61,11 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import split_entity_id, valid_entity_id
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import template as template_helper
|
||||
from homeassistant.helpers.logging import KeywordStyleAdapter
|
||||
from homeassistant.util import slugify as util_slugify
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
|
||||
|
@ -126,7 +138,7 @@ def boolean(value: Any) -> bool:
|
|||
raise vol.Invalid("invalid boolean value {}".format(value))
|
||||
|
||||
|
||||
def isdevice(value):
|
||||
def isdevice(value: Any) -> str:
|
||||
"""Validate that value is a real device."""
|
||||
try:
|
||||
os.stat(value)
|
||||
|
@ -135,19 +147,19 @@ def isdevice(value):
|
|||
raise vol.Invalid("No device at {} found".format(value))
|
||||
|
||||
|
||||
def matches_regex(regex):
|
||||
def matches_regex(regex: str) -> Callable[[Any], str]:
|
||||
"""Validate that the value is a string that matches a regex."""
|
||||
regex = re.compile(regex)
|
||||
compiled = re.compile(regex)
|
||||
|
||||
def validator(value: Any) -> str:
|
||||
"""Validate that value matches the given regex."""
|
||||
if not isinstance(value, str):
|
||||
raise vol.Invalid("not a string value: {}".format(value))
|
||||
|
||||
if not regex.match(value):
|
||||
if not compiled.match(value):
|
||||
raise vol.Invalid(
|
||||
"value {} does not match regular expression {}".format(
|
||||
value, regex.pattern
|
||||
value, compiled.pattern
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -156,14 +168,14 @@ def matches_regex(regex):
|
|||
return validator
|
||||
|
||||
|
||||
def is_regex(value):
|
||||
def is_regex(value: Any) -> Pattern[Any]:
|
||||
"""Validate that a string is a valid regular expression."""
|
||||
try:
|
||||
r = re.compile(value)
|
||||
return r
|
||||
except TypeError:
|
||||
raise vol.Invalid(
|
||||
"value {} is of the wrong type for a regular " "expression".format(value)
|
||||
"value {} is of the wrong type for a regular expression".format(value)
|
||||
)
|
||||
except re.error:
|
||||
raise vol.Invalid("value {} is not a valid regular expression".format(value))
|
||||
|
@ -204,9 +216,9 @@ def ensure_list(value: Union[T, List[T], None]) -> List[T]:
|
|||
|
||||
def entity_id(value: Any) -> str:
|
||||
"""Validate Entity ID."""
|
||||
value = string(value).lower()
|
||||
if valid_entity_id(value):
|
||||
return value
|
||||
str_value = string(value).lower()
|
||||
if valid_entity_id(str_value):
|
||||
return str_value
|
||||
|
||||
raise vol.Invalid("Entity ID {} is an invalid entity id".format(value))
|
||||
|
||||
|
@ -253,17 +265,17 @@ def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]:
|
|||
return validate
|
||||
|
||||
|
||||
def enum(enumClass):
|
||||
def enum(enumClass: Type[Enum]) -> vol.All:
|
||||
"""Create validator for specified enum."""
|
||||
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
|
||||
|
||||
|
||||
def icon(value):
|
||||
def icon(value: Any) -> str:
|
||||
"""Validate icon."""
|
||||
value = str(value)
|
||||
str_value = str(value)
|
||||
|
||||
if ":" in value:
|
||||
return value
|
||||
if ":" in str_value:
|
||||
return str_value
|
||||
|
||||
raise vol.Invalid('Icons should be specified in the form "prefix:name"')
|
||||
|
||||
|
@ -362,7 +374,7 @@ def time_period_seconds(value: Union[int, str]) -> timedelta:
|
|||
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
|
||||
|
||||
|
||||
def match_all(value):
|
||||
def match_all(value: T) -> T:
|
||||
"""Validate that matches all values."""
|
||||
return value
|
||||
|
||||
|
@ -382,12 +394,12 @@ def remove_falsy(value: List[T]) -> List[T]:
|
|||
return [v for v in value if v]
|
||||
|
||||
|
||||
def service(value):
|
||||
def service(value: Any) -> str:
|
||||
"""Validate service."""
|
||||
# Services use same format as entities so we can use same helper.
|
||||
value = string(value).lower()
|
||||
if valid_entity_id(value):
|
||||
return value
|
||||
str_value = string(value).lower()
|
||||
if valid_entity_id(str_value):
|
||||
return str_value
|
||||
raise vol.Invalid("Service {} does not match format <domain>.<name>".format(value))
|
||||
|
||||
|
||||
|
@ -407,7 +419,7 @@ def schema_with_slug_keys(value_schema: Union[T, Callable]) -> Callable:
|
|||
for key in value.keys():
|
||||
slug(key)
|
||||
|
||||
return schema(value)
|
||||
return cast(Dict, schema(value))
|
||||
|
||||
return verify
|
||||
|
||||
|
@ -416,10 +428,10 @@ def slug(value: Any) -> str:
|
|||
"""Validate value is a valid slug."""
|
||||
if value is None:
|
||||
raise vol.Invalid("Slug should not be None")
|
||||
value = str(value)
|
||||
slg = util_slugify(value)
|
||||
if value == slg:
|
||||
return value
|
||||
str_value = str(value)
|
||||
slg = util_slugify(str_value)
|
||||
if str_value == slg:
|
||||
return str_value
|
||||
raise vol.Invalid("invalid slug {} (try {})".format(value, slg))
|
||||
|
||||
|
||||
|
@ -458,42 +470,41 @@ unit_system = vol.All(
|
|||
)
|
||||
|
||||
|
||||
def template(value):
|
||||
def template(value: Optional[Any]) -> template_helper.Template:
|
||||
"""Validate a jinja2 template."""
|
||||
from homeassistant.helpers import template as template_helper
|
||||
|
||||
if value is None:
|
||||
raise vol.Invalid("template value is None")
|
||||
if isinstance(value, (list, dict, template_helper.Template)):
|
||||
raise vol.Invalid("template value should be a string")
|
||||
|
||||
value = template_helper.Template(str(value))
|
||||
template_value = template_helper.Template(str(value)) # type: ignore
|
||||
|
||||
try:
|
||||
value.ensure_valid()
|
||||
return value
|
||||
template_value.ensure_valid()
|
||||
return cast(template_helper.Template, template_value)
|
||||
except TemplateError as ex:
|
||||
raise vol.Invalid("invalid template ({})".format(ex))
|
||||
|
||||
|
||||
def template_complex(value):
|
||||
def template_complex(value: Any) -> Any:
|
||||
"""Validate a complex jinja2 template."""
|
||||
if isinstance(value, list):
|
||||
return_value = value.copy()
|
||||
for idx, element in enumerate(return_value):
|
||||
return_value[idx] = template_complex(element)
|
||||
return return_value
|
||||
return_list = value.copy()
|
||||
for idx, element in enumerate(return_list):
|
||||
return_list[idx] = template_complex(element)
|
||||
return return_list
|
||||
if isinstance(value, dict):
|
||||
return_value = value.copy()
|
||||
for key, element in return_value.items():
|
||||
return_value[key] = template_complex(element)
|
||||
return return_value
|
||||
return_dict = value.copy()
|
||||
for key, element in return_dict.items():
|
||||
return_dict[key] = template_complex(element)
|
||||
return return_dict
|
||||
if isinstance(value, str):
|
||||
return template(value)
|
||||
return value
|
||||
|
||||
|
||||
def datetime(value):
|
||||
def datetime(value: Any) -> datetime_sys:
|
||||
"""Validate datetime."""
|
||||
if isinstance(value, datetime_sys):
|
||||
return value
|
||||
|
@ -509,7 +520,7 @@ def datetime(value):
|
|||
return date_val
|
||||
|
||||
|
||||
def time_zone(value):
|
||||
def time_zone(value: str) -> str:
|
||||
"""Validate timezone."""
|
||||
if dt_util.get_time_zone(value) is not None:
|
||||
return value
|
||||
|
@ -522,7 +533,7 @@ def time_zone(value):
|
|||
weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
|
||||
|
||||
|
||||
def socket_timeout(value):
|
||||
def socket_timeout(value: Optional[Any]) -> object:
|
||||
"""Validate timeout float > 0.0.
|
||||
|
||||
None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
|
||||
|
@ -544,12 +555,12 @@ def url(value: Any) -> str:
|
|||
url_in = str(value)
|
||||
|
||||
if urlparse(url_in).scheme in ["http", "https"]:
|
||||
return vol.Schema(vol.Url())(url_in)
|
||||
return cast(str, vol.Schema(vol.Url())(url_in))
|
||||
|
||||
raise vol.Invalid("invalid url")
|
||||
|
||||
|
||||
def x10_address(value):
|
||||
def x10_address(value: str) -> str:
|
||||
"""Validate an x10 address."""
|
||||
regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$")
|
||||
if not regex.match(value):
|
||||
|
@ -557,7 +568,7 @@ def x10_address(value):
|
|||
return str(value).lower()
|
||||
|
||||
|
||||
def uuid4_hex(value):
|
||||
def uuid4_hex(value: Any) -> str:
|
||||
"""Validate a v4 UUID in hex format."""
|
||||
try:
|
||||
result = UUID(value, version=4)
|
||||
|
@ -678,10 +689,12 @@ def deprecated(
|
|||
# Validator helpers
|
||||
|
||||
|
||||
def key_dependency(key, dependency):
|
||||
def key_dependency(
|
||||
key: Hashable, dependency: Hashable
|
||||
) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]:
|
||||
"""Validate that all dependencies exist for key."""
|
||||
|
||||
def validator(value):
|
||||
def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
|
||||
"""Test dependencies."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("key dependencies require a dict")
|
||||
|
@ -696,7 +709,7 @@ def key_dependency(key, dependency):
|
|||
return validator
|
||||
|
||||
|
||||
def custom_serializer(schema):
|
||||
def custom_serializer(schema: Any) -> Any:
|
||||
"""Serialize additional types for voluptuous_serialize."""
|
||||
if schema is positive_time_period_dict:
|
||||
return {"type": "positive_time_period_dict"}
|
||||
|
|
|
@ -12,8 +12,7 @@ from homeassistant.loader import bind_hass
|
|||
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_UNDEF = object()
|
||||
|
@ -71,10 +70,11 @@ def format_mac(mac: str) -> str:
|
|||
class DeviceRegistry:
|
||||
"""Class to hold a registry of devices."""
|
||||
|
||||
def __init__(self, hass):
|
||||
devices: Dict[str, DeviceEntry]
|
||||
|
||||
def __init__(self, hass: HomeAssistantType) -> None:
|
||||
"""Initialize the device registry."""
|
||||
self.hass = hass
|
||||
self.devices = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
|
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -23,6 +23,9 @@ from homeassistant.util.yaml import load_yaml
|
|||
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.config_entries import ConfigEntry # noqa: F401
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
PATH_REGISTRY = "entity_registry.yaml"
|
||||
|
@ -48,7 +51,7 @@ class RegistryEntry:
|
|||
unique_id = attr.ib(type=str)
|
||||
platform = attr.ib(type=str)
|
||||
name = attr.ib(type=str, default=None)
|
||||
device_id = attr.ib(type=str, default=None)
|
||||
device_id: Optional[str] = attr.ib(default=None)
|
||||
config_entry_id: Optional[str] = attr.ib(default=None)
|
||||
disabled_by = attr.ib(
|
||||
type=Optional[str],
|
||||
|
@ -135,16 +138,16 @@ class EntityRegistry:
|
|||
@callback
|
||||
def async_get_or_create(
|
||||
self,
|
||||
domain,
|
||||
platform,
|
||||
unique_id,
|
||||
domain: str,
|
||||
platform: str,
|
||||
unique_id: str,
|
||||
*,
|
||||
suggested_object_id=None,
|
||||
config_entry=None,
|
||||
device_id=None,
|
||||
known_object_ids=None,
|
||||
disabled_by=None,
|
||||
):
|
||||
suggested_object_id: Optional[str] = None,
|
||||
config_entry: Optional["ConfigEntry"] = None,
|
||||
device_id: Optional[str] = None,
|
||||
known_object_ids: Optional[Iterable[str]] = None,
|
||||
disabled_by: Optional[str] = None,
|
||||
) -> RegistryEntry:
|
||||
"""Get entity. Create if it doesn't exist."""
|
||||
config_entry_id = None
|
||||
if config_entry:
|
||||
|
@ -153,7 +156,7 @@ class EntityRegistry:
|
|||
entity_id = self.async_get_entity_id(domain, platform, unique_id)
|
||||
|
||||
if entity_id:
|
||||
return self._async_update_entity(
|
||||
return self._async_update_entity( # type: ignore
|
||||
entity_id,
|
||||
config_entry_id=config_entry_id or _UNDEF,
|
||||
device_id=device_id or _UNDEF,
|
||||
|
@ -228,12 +231,15 @@ class EntityRegistry:
|
|||
disabled_by=_UNDEF,
|
||||
):
|
||||
"""Update properties of an entity."""
|
||||
return self._async_update_entity(
|
||||
entity_id,
|
||||
name=name,
|
||||
new_entity_id=new_entity_id,
|
||||
new_unique_id=new_unique_id,
|
||||
disabled_by=disabled_by,
|
||||
return cast( # cast until we have _async_update_entity type hinted
|
||||
RegistryEntry,
|
||||
self._async_update_entity(
|
||||
entity_id,
|
||||
name=name,
|
||||
new_entity_id=new_entity_id,
|
||||
new_unique_id=new_unique_id,
|
||||
disabled_by=disabled_by,
|
||||
),
|
||||
)
|
||||
|
||||
@callback
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
"""Helpers for logging allowing more advanced logging styles to be used."""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
from typing import Any, Mapping, MutableMapping, Optional, Tuple
|
||||
|
||||
|
||||
class KeywordMessage:
|
||||
|
@ -12,13 +11,13 @@ class KeywordMessage:
|
|||
Adapted from: https://stackoverflow.com/a/24683360/2267718
|
||||
"""
|
||||
|
||||
def __init__(self, fmt, args, kwargs):
|
||||
"""Initialize a new BraceMessage object."""
|
||||
def __init__(self, fmt: Any, args: Any, kwargs: Mapping[str, Any]) -> None:
|
||||
"""Initialize a new KeywordMessage object."""
|
||||
self._fmt = fmt
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Convert the object to a string for logging."""
|
||||
return str(self._fmt).format(*self._args, **self._kwargs)
|
||||
|
||||
|
@ -26,26 +25,30 @@ class KeywordMessage:
|
|||
class KeywordStyleAdapter(logging.LoggerAdapter):
|
||||
"""Represents an adapter wrapping the logger allowing KeywordMessages."""
|
||||
|
||||
def __init__(self, logger, extra=None):
|
||||
def __init__(
|
||||
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
|
||||
) -> None:
|
||||
"""Initialize a new StyleAdapter for the provided logger."""
|
||||
super().__init__(logger, extra or {})
|
||||
|
||||
def log(self, level, msg, *args, **kwargs):
|
||||
def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None:
|
||||
"""Log the message provided at the appropriate level."""
|
||||
if self.isEnabledFor(level):
|
||||
msg, log_kwargs = self.process(msg, kwargs)
|
||||
self.logger._log( # pylint: disable=protected-access
|
||||
self.logger._log( # type: ignore # pylint: disable=protected-access
|
||||
level, KeywordMessage(msg, args, kwargs), (), **log_kwargs
|
||||
)
|
||||
|
||||
def process(self, msg, kwargs):
|
||||
def process(
|
||||
self, msg: Any, kwargs: MutableMapping[str, Any]
|
||||
) -> Tuple[Any, MutableMapping[str, Any]]:
|
||||
"""Process the keyward args in preparation for logging."""
|
||||
return (
|
||||
msg,
|
||||
{
|
||||
k: kwargs[k]
|
||||
for k in inspect.getfullargspec(
|
||||
self.logger._log # pylint: disable=protected-access
|
||||
self.logger._log # type: ignore # pylint: disable=protected-access
|
||||
).args[1:]
|
||||
if k in kwargs
|
||||
},
|
||||
|
|
Loading…
Add table
Reference in a new issue