Generics and other type hint improvements (#45250)

This commit is contained in:
Ville Skyttä 2021-01-18 23:23:25 +02:00 committed by GitHub
parent 4928476abe
commit 94dbcc9d2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 92 additions and 45 deletions

View file

@ -571,7 +571,7 @@ class Context:
parent_id: Optional[str] = attr.ib(default=None) parent_id: Optional[str] = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex) id: str = attr.ib(factory=uuid_util.random_uuid_hex)
def as_dict(self) -> dict: def as_dict(self) -> Dict[str, Optional[str]]:
"""Return a dictionary representation of the context.""" """Return a dictionary representation of the context."""
return {"id": self.id, "parent_id": self.parent_id, "user_id": self.user_id} return {"id": self.id, "parent_id": self.parent_id, "user_id": self.user_id}
@ -612,7 +612,7 @@ class Event:
# The only event type that shares context are the TIME_CHANGED # The only event type that shares context are the TIME_CHANGED
return hash((self.event_type, self.context.id, self.time_fired)) return hash((self.event_type, self.context.id, self.time_fired))
def as_dict(self) -> Dict: def as_dict(self) -> Dict[str, Any]:
"""Create a dict representation of this Event. """Create a dict representation of this Event.
Async friendly. Async friendly.
@ -682,7 +682,7 @@ class EventBus:
def async_fire( def async_fire(
self, self,
event_type: str, event_type: str,
event_data: Optional[Dict] = None, event_data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
context: Optional[Context] = None, context: Optional[Context] = None,
time_fired: Optional[datetime.datetime] = None, time_fired: Optional[datetime.datetime] = None,
@ -844,7 +844,7 @@ class State:
self, self,
entity_id: str, entity_id: str,
state: str, state: str,
attributes: Optional[Mapping] = None, attributes: Optional[Mapping[str, Any]] = None,
last_changed: Optional[datetime.datetime] = None, last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None, last_updated: Optional[datetime.datetime] = None,
context: Optional[Context] = None, context: Optional[Context] = None,
@ -1091,7 +1091,7 @@ class StateMachine:
self, self,
entity_id: str, entity_id: str,
new_state: str, new_state: str,
attributes: Optional[Dict] = None, attributes: Optional[Mapping[str, Any]] = None,
force_update: bool = False, force_update: bool = False,
context: Optional[Context] = None, context: Optional[Context] = None,
) -> None: ) -> None:
@ -1140,7 +1140,7 @@ class StateMachine:
self, self,
entity_id: str, entity_id: str,
new_state: str, new_state: str,
attributes: Optional[Dict] = None, attributes: Optional[Mapping[str, Any]] = None,
force_update: bool = False, force_update: bool = False,
context: Optional[Context] = None, context: Optional[Context] = None,
) -> None: ) -> None:

View file

@ -6,18 +6,20 @@ from typing import Any, Dict, Optional, Pattern
from homeassistant.core import split_entity_id from homeassistant.core import split_entity_id
# mypy: disallow-any-generics
class EntityValues: class EntityValues:
"""Class to store entity id based values.""" """Class to store entity id based values."""
def __init__( def __init__(
self, self,
exact: Optional[Dict] = None, exact: Optional[Dict[str, Dict[str, str]]] = None,
domain: Optional[Dict] = None, domain: Optional[Dict[str, Dict[str, str]]] = None,
glob: Optional[Dict] = None, glob: Optional[Dict[str, Dict[str, str]]] = None,
) -> None: ) -> None:
"""Initialize an EntityConfigDict.""" """Initialize an EntityConfigDict."""
self._cache: Dict[str, Dict] = {} self._cache: Dict[str, Dict[str, str]] = {}
self._exact = exact self._exact = exact
self._domain = domain self._domain = domain
@ -30,7 +32,7 @@ class EntityValues:
self._glob = compiled self._glob = compiled
def get(self, entity_id: str) -> Dict: def get(self, entity_id: str) -> Dict[str, str]:
"""Get config for an entity id.""" """Get config for an entity id."""
if entity_id in self._cache: if entity_id in self._cache:
return self._cache[entity_id] return self._cache[entity_id]

View file

@ -20,6 +20,7 @@ from typing import (
List, List,
Optional, Optional,
Set, Set,
TypedDict,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -34,7 +35,11 @@ from homeassistant.generated.zeroconf import HOMEKIT, ZEROCONF
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name # mypy: disallow-any-generics
CALLABLE_T = TypeVar( # pylint: disable=invalid-name
"CALLABLE_T", bound=Callable[..., Any]
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -54,12 +59,38 @@ _UNDEF = object() # Internal; not helpers.typing.UNDEFINED due to circular depe
MAX_LOAD_CONCURRENTLY = 4 MAX_LOAD_CONCURRENTLY = 4
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Dict: class Manifest(TypedDict, total=False):
"""
Integration manifest.
Note that none of the attributes are marked Optional here. However, some of them may be optional in manifest.json
in the sense that they can be omitted altogether. But when present, they should not have null values in it.
"""
name: str
disabled: str
domain: str
dependencies: List[str]
after_dependencies: List[str]
requirements: List[str]
config_flow: bool
documentation: str
issue_tracker: str
quality_scale: str
mqtt: List[str]
ssdp: List[Dict[str, str]]
zeroconf: List[Union[str, Dict[str, str]]]
dhcp: List[Dict[str, str]]
homekit: Dict[str, List[str]]
is_built_in: bool
codeowners: List[str]
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
"""Generate a manifest from a legacy module.""" """Generate a manifest from a legacy module."""
return { return {
"domain": domain, "domain": domain,
"name": domain, "name": domain,
"documentation": None,
"requirements": getattr(module, "REQUIREMENTS", []), "requirements": getattr(module, "REQUIREMENTS", []),
"dependencies": getattr(module, "DEPENDENCIES", []), "dependencies": getattr(module, "DEPENDENCIES", []),
"codeowners": [], "codeowners": [],
@ -205,10 +236,10 @@ async def async_get_homekit(hass: "HomeAssistant") -> Dict[str, str]:
return homekit return homekit
async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List]: async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List[Dict[str, str]]]:
"""Return cached list of ssdp mappings.""" """Return cached list of ssdp mappings."""
ssdp: Dict[str, List] = SSDP.copy() ssdp: Dict[str, List[Dict[str, str]]] = SSDP.copy()
integrations = await async_get_custom_components(hass) integrations = await async_get_custom_components(hass)
for integration in integrations.values(): for integration in integrations.values():
@ -220,10 +251,10 @@ async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List]:
return ssdp return ssdp
async def async_get_mqtt(hass: "HomeAssistant") -> Dict[str, List]: async def async_get_mqtt(hass: "HomeAssistant") -> Dict[str, List[str]]:
"""Return cached list of MQTT mappings.""" """Return cached list of MQTT mappings."""
mqtt: Dict[str, List] = MQTT.copy() mqtt: Dict[str, List[str]] = MQTT.copy()
integrations = await async_get_custom_components(hass) integrations = await async_get_custom_components(hass)
for integration in integrations.values(): for integration in integrations.values():
@ -288,7 +319,7 @@ class Integration:
hass: "HomeAssistant", hass: "HomeAssistant",
pkg_path: str, pkg_path: str,
file_path: pathlib.Path, file_path: pathlib.Path,
manifest: Dict[str, Any], manifest: Manifest,
): ):
"""Initialize an integration.""" """Initialize an integration."""
self.hass = hass self.hass = hass
@ -309,77 +340,77 @@ class Integration:
@property @property
def name(self) -> str: def name(self) -> str:
"""Return name.""" """Return name."""
return cast(str, self.manifest["name"]) return self.manifest["name"]
@property @property
def disabled(self) -> Optional[str]: def disabled(self) -> Optional[str]:
"""Return reason integration is disabled.""" """Return reason integration is disabled."""
return cast(Optional[str], self.manifest.get("disabled")) return self.manifest.get("disabled")
@property @property
def domain(self) -> str: def domain(self) -> str:
"""Return domain.""" """Return domain."""
return cast(str, self.manifest["domain"]) return self.manifest["domain"]
@property @property
def dependencies(self) -> List[str]: def dependencies(self) -> List[str]:
"""Return dependencies.""" """Return dependencies."""
return cast(List[str], self.manifest.get("dependencies", [])) return self.manifest.get("dependencies", [])
@property @property
def after_dependencies(self) -> List[str]: def after_dependencies(self) -> List[str]:
"""Return after_dependencies.""" """Return after_dependencies."""
return cast(List[str], self.manifest.get("after_dependencies", [])) return self.manifest.get("after_dependencies", [])
@property @property
def requirements(self) -> List[str]: def requirements(self) -> List[str]:
"""Return requirements.""" """Return requirements."""
return cast(List[str], self.manifest.get("requirements", [])) return self.manifest.get("requirements", [])
@property @property
def config_flow(self) -> bool: def config_flow(self) -> bool:
"""Return config_flow.""" """Return config_flow."""
return cast(bool, self.manifest.get("config_flow", False)) return self.manifest.get("config_flow") or False
@property @property
def documentation(self) -> Optional[str]: def documentation(self) -> Optional[str]:
"""Return documentation.""" """Return documentation."""
return cast(str, self.manifest.get("documentation")) return self.manifest.get("documentation")
@property @property
def issue_tracker(self) -> Optional[str]: def issue_tracker(self) -> Optional[str]:
"""Return issue tracker link.""" """Return issue tracker link."""
return cast(str, self.manifest.get("issue_tracker")) return self.manifest.get("issue_tracker")
@property @property
def quality_scale(self) -> Optional[str]: def quality_scale(self) -> Optional[str]:
"""Return Integration Quality Scale.""" """Return Integration Quality Scale."""
return cast(str, self.manifest.get("quality_scale")) return self.manifest.get("quality_scale")
@property @property
def mqtt(self) -> Optional[list]: def mqtt(self) -> Optional[List[str]]:
"""Return Integration MQTT entries.""" """Return Integration MQTT entries."""
return cast(List[dict], self.manifest.get("mqtt")) return self.manifest.get("mqtt")
@property @property
def ssdp(self) -> Optional[list]: def ssdp(self) -> Optional[List[Dict[str, str]]]:
"""Return Integration SSDP entries.""" """Return Integration SSDP entries."""
return cast(List[dict], self.manifest.get("ssdp")) return self.manifest.get("ssdp")
@property @property
def zeroconf(self) -> Optional[list]: def zeroconf(self) -> Optional[List[Union[str, Dict[str, str]]]]:
"""Return Integration zeroconf entries.""" """Return Integration zeroconf entries."""
return cast(List[str], self.manifest.get("zeroconf")) return self.manifest.get("zeroconf")
@property @property
def dhcp(self) -> Optional[list]: def dhcp(self) -> Optional[List[Dict[str, str]]]:
"""Return Integration dhcp entries.""" """Return Integration dhcp entries."""
return cast(List[str], self.manifest.get("dhcp")) return self.manifest.get("dhcp")
@property @property
def homekit(self) -> Optional[dict]: def homekit(self) -> Optional[Dict[str, List[str]]]:
"""Return Integration homekit entries.""" """Return Integration homekit entries."""
return cast(Dict[str, List], self.manifest.get("homekit")) return self.manifest.get("homekit")
@property @property
def is_built_in(self) -> bool: def is_built_in(self) -> bool:

View file

@ -9,6 +9,8 @@ from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from homeassistant.loader import Integration, IntegrationNotFound, async_get_integration from homeassistant.loader import Integration, IntegrationNotFound, async_get_integration
import homeassistant.util.package as pkg_util import homeassistant.util.package as pkg_util
# mypy: disallow-any-generics
DATA_PIP_LOCK = "pip_lock" DATA_PIP_LOCK = "pip_lock"
DATA_PKG_CACHE = "pkg_cache" DATA_PKG_CACHE = "pkg_cache"
DATA_INTEGRATIONS_WITH_REQS = "integrations_with_reqs" DATA_INTEGRATIONS_WITH_REQS = "integrations_with_reqs"
@ -24,7 +26,7 @@ DISCOVERY_INTEGRATIONS: Dict[str, Iterable[str]] = {
class RequirementsNotFound(HomeAssistantError): class RequirementsNotFound(HomeAssistantError):
"""Raised when a component is not found.""" """Raised when a component is not found."""
def __init__(self, domain: str, requirements: List) -> None: def __init__(self, domain: str, requirements: List[str]) -> None:
"""Initialize a component not found error.""" """Initialize a component not found error."""
super().__init__(f"Requirements for {domain} not found: {requirements}.") super().__init__(f"Requirements for {domain} not found: {requirements}.")
self.domain = domain self.domain = domain
@ -124,7 +126,7 @@ async def async_process_requirements(
if pkg_util.is_installed(req): if pkg_util.is_installed(req):
continue continue
def _install(req: str, kwargs: Dict) -> bool: def _install(req: str, kwargs: Dict[str, Any]) -> bool:
"""Install requirement.""" """Install requirement."""
return pkg_util.install_package(req, **kwargs) return pkg_util.install_package(req, **kwargs)

View file

@ -9,6 +9,8 @@ from homeassistant import bootstrap
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.frame import warn_use from homeassistant.helpers.frame import warn_use
# mypy: disallow-any-generics
# #
# Python 3.8 has significantly less workers by default # Python 3.8 has significantly less workers by default
# than Python 3.7. In order to be consistent between # than Python 3.7. In order to be consistent between
@ -81,7 +83,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
@callback @callback
def _async_loop_exception_handler(_: Any, context: Dict) -> None: def _async_loop_exception_handler(_: Any, context: Dict[str, Any]) -> None:
"""Handle all exception inside the core loop.""" """Handle all exception inside the core loop."""
kwargs = {} kwargs = {}
exception = context.get("exception") exception = context.get("exception")

View file

@ -1,9 +1,17 @@
"""Util to handle processes.""" """Util to handle processes."""
from __future__ import annotations
import subprocess import subprocess
from typing import Any
# mypy: disallow-any-generics
def kill_subprocess(process: subprocess.Popen) -> None: def kill_subprocess(
# pylint: disable=unsubscriptable-object # https://github.com/PyCQA/pylint/issues/4034
process: subprocess.Popen[Any],
) -> None:
"""Force kill a subprocess and wait for it to exit.""" """Force kill a subprocess and wait for it to exit."""
process.kill() process.kill()
process.communicate() process.communicate()

View file

@ -1,6 +1,6 @@
"""Unit system helper class and methods.""" """Unit system helper class and methods."""
from numbers import Number from numbers import Number
from typing import Optional from typing import Dict, Optional
from homeassistant.const import ( from homeassistant.const import (
CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_IMPERIAL,
@ -31,6 +31,8 @@ from homeassistant.util import (
volume as volume_util, volume as volume_util,
) )
# mypy: disallow-any-generics
LENGTH_UNITS = distance_util.VALID_UNITS LENGTH_UNITS = distance_util.VALID_UNITS
MASS_UNITS = [MASS_POUNDS, MASS_OUNCES, MASS_KILOGRAMS, MASS_GRAMS] MASS_UNITS = [MASS_POUNDS, MASS_OUNCES, MASS_KILOGRAMS, MASS_GRAMS]
@ -135,7 +137,7 @@ class UnitSystem:
# type ignore: https://github.com/python/mypy/issues/7207 # type ignore: https://github.com/python/mypy/issues/7207
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
def as_dict(self) -> dict: def as_dict(self) -> Dict[str, str]:
"""Convert the unit system to a dictionary.""" """Convert the unit system to a dictionary."""
return { return {
LENGTH: self.length_unit, LENGTH: self.length_unit,