Add HassDict implementation (#103844)

This commit is contained in:
Marc Mueller 2024-05-07 10:53:13 +02:00 committed by GitHub
parent fd52588565
commit 3d700e2b71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 287 additions and 36 deletions

View file

@ -2035,9 +2035,7 @@ class ConfigEntries:
Config entries which are created after Home Assistant is started can't be waited
for, the function will just return if the config entry is loaded or not.
"""
setup_done: dict[str, asyncio.Future[bool]] = self.hass.data.get(
DATA_SETUP_DONE, {}
)
setup_done = self.hass.data.get(DATA_SETUP_DONE, {})
if setup_future := setup_done.get(entry.domain):
await setup_future
# The component was not loaded.

View file

@ -104,6 +104,7 @@ from .util.async_ import (
)
from .util.event_type import EventType
from .util.executor import InterruptibleThreadPoolExecutor
from .util.hass_dict import HassDict
from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager
@ -406,7 +407,7 @@ class HomeAssistant:
from . import loader
# This is a dictionary that any component can store any data on.
self.data: dict[str, Any] = {}
self.data = HassDict()
self.loop = asyncio.get_running_loop()
self._tasks: set[asyncio.Future[Any]] = set()
self._background_tasks: set[asyncio.Future[Any]] = set()

View file

@ -5,17 +5,26 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import functools
from typing import Any, TypeVar, cast
from typing import Any, TypeVar, cast, overload
from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey
_T = TypeVar("_T")
_FuncType = Callable[[HomeAssistant], _T]
def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
@overload
def singleton(data_key: HassKey[_T]) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...
@overload
def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...
def singleton(data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
"""Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled.

View file

@ -33,6 +33,7 @@ from .helpers import singleton, translation
from .helpers.issue_registry import IssueSeverity, async_create_issue
from .helpers.typing import ConfigType
from .util.async_ import create_eager_task
from .util.hass_dict import HassKey
current_setup_group: contextvars.ContextVar[tuple[str, str | None] | None] = (
contextvars.ContextVar("current_setup_group", default=None)
@ -45,29 +46,32 @@ ATTR_COMPONENT: Final = "component"
BASE_PLATFORMS = {platform.value for platform in Platform}
# DATA_SETUP is a dict[str, asyncio.Future[bool]], indicating domains which are currently
# DATA_SETUP is a dict, indicating domains which are currently
# being setup or which failed to setup:
# - Tasks are added to DATA_SETUP by `async_setup_component`, the key is the domain
# being setup and the Task is the `_async_setup_component` helper.
# - Tasks are removed from DATA_SETUP if setup was successful, that is,
# the task returned True.
DATA_SETUP = "setup_tasks"
DATA_SETUP: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_tasks")
# DATA_SETUP_DONE is a dict [str, asyncio.Future[bool]], indicating components which
# will be setup:
# DATA_SETUP_DONE is a dict, indicating components which will be setup:
# - Events are added to DATA_SETUP_DONE during bootstrap by
# async_set_domains_to_be_loaded, the key is the domain which will be loaded.
# - Events are set and removed from DATA_SETUP_DONE when async_setup_component
# is finished, regardless of if the setup was successful or not.
DATA_SETUP_DONE = "setup_done"
DATA_SETUP_DONE: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_done")
# DATA_SETUP_STARTED is a dict [tuple[str, str | None], float], indicating when an attempt
# DATA_SETUP_STARTED is a dict, indicating when an attempt
# to setup a component started.
DATA_SETUP_STARTED = "setup_started"
DATA_SETUP_STARTED: HassKey[dict[tuple[str, str | None], float]] = HassKey(
"setup_started"
)
# DATA_SETUP_TIME is a defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]]
# indicating how time was spent setting up a component and each group (config entry).
DATA_SETUP_TIME = "setup_time"
# DATA_SETUP_TIME is a defaultdict, indicating how time was spent
# setting up a component.
DATA_SETUP_TIME: HassKey[
defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]]
] = HassKey("setup_time")
DATA_DEPS_REQS = "deps_reqs_processed"
@ -126,9 +130,7 @@ def async_set_domains_to_be_loaded(hass: core.HomeAssistant, domains: set[str])
- Properly handle after_dependencies.
- Keep track of domains which will load but have not yet finished loading
"""
setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP_DONE, {}
)
setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {})
setup_done_futures.update({domain: hass.loop.create_future() for domain in domains})
@ -149,12 +151,8 @@ async def async_setup_component(
if domain in hass.config.components:
return True
setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP, {}
)
setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP_DONE, {}
)
setup_futures = hass.data.setdefault(DATA_SETUP, {})
setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {})
if existing_setup_future := setup_futures.get(domain):
return await existing_setup_future
@ -195,9 +193,7 @@ async def _async_process_dependencies(
Returns a list of dependencies which failed to set up.
"""
setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP, {}
)
setup_futures = hass.data.setdefault(DATA_SETUP, {})
dependencies_tasks = {
dep: setup_futures.get(dep)
@ -210,7 +206,7 @@ async def _async_process_dependencies(
}
after_dependencies_tasks: dict[str, asyncio.Future[bool]] = {}
to_be_loaded: dict[str, asyncio.Future[bool]] = hass.data.get(DATA_SETUP_DONE, {})
to_be_loaded = hass.data.get(DATA_SETUP_DONE, {})
for dep in integration.after_dependencies:
if (
dep not in dependencies_tasks

View file

@ -0,0 +1,31 @@
"""Implementation for HassDict and custom HassKey types.
Custom for type checking. See stub file.
"""
from __future__ import annotations
from typing import Generic, TypeVar
_T = TypeVar("_T")
class HassKey(str, Generic[_T]):
"""Generic Hass key type.
At runtime this is a generic subclass of str.
"""
__slots__ = ()
class HassEntryKey(str, Generic[_T]):
"""Key type for integrations with config entries.
At runtime this is a generic subclass of str.
"""
__slots__ = ()
HassDict = dict

View file

@ -0,0 +1,176 @@
"""Stub file for hass_dict. Provide overload for type checking."""
# ruff: noqa: PYI021 # Allow docstrings
from typing import Any, Generic, TypeVar, assert_type, overload
__all__ = [
"HassDict",
"HassEntryKey",
"HassKey",
]
_T = TypeVar("_T")
_U = TypeVar("_U")
class _Key(Generic[_T]):
"""Base class for Hass key types. At runtime delegated to str."""
def __init__(self, value: str, /) -> None: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
def __getitem__(self, index: int) -> str: ...
class HassEntryKey(_Key[_T]):
"""Key type for integrations with config entries."""
class HassKey(_Key[_T]):
"""Generic Hass key type."""
class HassDict(dict[_Key[Any] | str, Any]):
"""Custom dict type to provide better value type hints for Hass key types."""
@overload # type: ignore[override]
def __getitem__(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ...
@overload
def __getitem__(self, key: HassKey[_T], /) -> _T: ...
@overload
def __getitem__(self, key: str, /) -> Any: ...
# ------
@overload # type: ignore[override]
def __setitem__(self, key: HassEntryKey[_T], value: dict[str, _T], /) -> None: ...
@overload
def __setitem__(self, key: HassKey[_T], value: _T, /) -> None: ...
@overload
def __setitem__(self, key: str, value: Any, /) -> None: ...
# ------
@overload # type: ignore[override]
def setdefault(
self, key: HassEntryKey[_T], default: dict[str, _T], /
) -> dict[str, _T]: ...
@overload
def setdefault(self, key: HassKey[_T], default: _T, /) -> _T: ...
@overload
def setdefault(self, key: str, default: None = None, /) -> Any | None: ...
@overload
def setdefault(self, key: str, default: Any, /) -> Any: ...
# ------
@overload # type: ignore[override]
def get(self, key: HassEntryKey[_T], /) -> dict[str, _T] | None: ...
@overload
def get(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ...
@overload
def get(self, key: HassKey[_T], /) -> _T | None: ...
@overload
def get(self, key: HassKey[_T], default: _U, /) -> _T | _U: ...
@overload
def get(self, key: str, /) -> Any | None: ...
@overload
def get(self, key: str, default: Any, /) -> Any: ...
# ------
@overload # type: ignore[override]
def pop(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ...
@overload
def pop(
self, key: HassEntryKey[_T], default: dict[str, _T], /
) -> dict[str, _T]: ...
@overload
def pop(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ...
@overload
def pop(self, key: HassKey[_T], /) -> _T: ...
@overload
def pop(self, key: HassKey[_T], default: _T, /) -> _T: ...
@overload
def pop(self, key: HassKey[_T], default: _U, /) -> _T | _U: ...
@overload
def pop(self, key: str, /) -> Any: ...
@overload
def pop(self, key: str, default: _U, /) -> Any | _U: ...
def _test_hass_dict_typing() -> None: # noqa: PYI048
"""Test HassDict overloads work as intended.
This is tested during the mypy run. Do not move it to 'tests'!
"""
d = HassDict()
entry_key = HassEntryKey[int]("entry_key")
key = HassKey[int]("key")
key2 = HassKey[dict[int, bool]]("key2")
key3 = HassKey[set[str]]("key3")
other_key = "domain"
# __getitem__
assert_type(d[entry_key], dict[str, int])
assert_type(d[entry_key]["entry_id"], int)
assert_type(d[key], int)
assert_type(d[key2], dict[int, bool])
# __setitem__
d[entry_key] = {}
d[entry_key] = 2 # type: ignore[call-overload]
d[entry_key]["entry_id"] = 2
d[entry_key]["entry_id"] = "Hello World" # type: ignore[assignment]
d[key] = 2
d[key] = "Hello World" # type: ignore[misc]
d[key] = {} # type: ignore[misc]
d[key2] = {}
d[key2] = 2 # type: ignore[misc]
d[key3] = set()
d[key3] = 2 # type: ignore[misc]
d[other_key] = 2
d[other_key] = "Hello World"
# get
assert_type(d.get(entry_key), dict[str, int] | None)
assert_type(d.get(entry_key, True), dict[str, int] | bool)
assert_type(d.get(key), int | None)
assert_type(d.get(key, True), int | bool)
assert_type(d.get(key2), dict[int, bool] | None)
assert_type(d.get(key2, {}), dict[int, bool])
assert_type(d.get(key3), set[str] | None)
assert_type(d.get(key3, set()), set[str])
assert_type(d.get(other_key), Any | None)
assert_type(d.get(other_key, True), Any)
assert_type(d.get(other_key, {})["id"], Any)
# setdefault
assert_type(d.setdefault(entry_key, {}), dict[str, int])
assert_type(d.setdefault(entry_key, {})["entry_id"], int)
assert_type(d.setdefault(key, 2), int)
assert_type(d.setdefault(key2, {}), dict[int, bool])
assert_type(d.setdefault(key2, {})[2], bool)
assert_type(d.setdefault(key3, set()), set[str])
assert_type(d.setdefault(other_key, 2), Any)
assert_type(d.setdefault(other_key), Any | None)
d.setdefault(entry_key, {})["entry_id"] = 2
d.setdefault(entry_key, {})["entry_id"] = "Hello World" # type: ignore[assignment]
d.setdefault(key, 2)
d.setdefault(key, "Error") # type: ignore[misc]
d.setdefault(key2, {})[2] = True
d.setdefault(key2, {})[2] = "Error" # type: ignore[assignment]
d.setdefault(key3, set()).add("Hello World")
d.setdefault(key3, set()).add(2) # type: ignore[arg-type]
d.setdefault(other_key, {})["id"] = 2
d.setdefault(other_key, {})["id"] = "Hello World"
d.setdefault(entry_key) # type: ignore[call-overload]
d.setdefault(key) # type: ignore[call-overload]
d.setdefault(key2) # type: ignore[call-overload]
# pop
assert_type(d.pop(entry_key), dict[str, int])
assert_type(d.pop(entry_key, {}), dict[str, int])
assert_type(d.pop(entry_key, 2), dict[str, int] | int)
assert_type(d.pop(key), int)
assert_type(d.pop(key, 2), int)
assert_type(d.pop(key, "Hello World"), int | str)
assert_type(d.pop(key2), dict[int, bool])
assert_type(d.pop(key2, {}), dict[int, bool])
assert_type(d.pop(key2, 2), dict[int, bool] | int)
assert_type(d.pop(key3), set[str])
assert_type(d.pop(key3, set()), set[str])
assert_type(d.pop(other_key), Any)
assert_type(d.pop(other_key, True), Any | bool)

View file

@ -739,7 +739,6 @@ async def test_integration_only_setup_entry(hass: HomeAssistant) -> None:
async def test_async_start_setup_running(hass: HomeAssistant) -> None:
"""Test setup started context manager does nothing when running."""
assert hass.state is CoreState.running
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
with setup.async_start_setup(
@ -753,7 +752,6 @@ async def test_async_start_setup_config_entry(
) -> None:
"""Test setup started keeps track of setup times with a config entry."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)
@ -864,7 +862,6 @@ async def test_async_start_setup_config_entry_late_platform(
) -> None:
"""Test setup started tracks config entry time with a late platform load."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)
@ -919,7 +916,6 @@ async def test_async_start_setup_config_entry_platform_wait(
) -> None:
"""Test setup started tracks wait time when a platform loads inside of config entry setup."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)
@ -962,7 +958,6 @@ async def test_async_start_setup_config_entry_platform_wait(
async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None:
"""Test setup started context manager keeps track of setup times with modern yaml."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)
@ -979,7 +974,6 @@ async def test_async_start_setup_top_level_yaml(hass: HomeAssistant) -> None:
async def test_async_start_setup_platform_integration(hass: HomeAssistant) -> None:
"""Test setup started keeps track of setup times a platform integration."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)
@ -1014,7 +1008,6 @@ async def test_async_start_setup_legacy_platform_integration(
) -> None:
"""Test setup started keeps track of setup times for a legacy platform integration."""
hass.set_state(CoreState.not_running)
setup_started: dict[tuple[str, str | None], float]
setup_started = hass.data.setdefault(setup.DATA_SETUP_STARTED, {})
setup_time = setup._setup_times(hass)

View file

@ -0,0 +1,47 @@
"""Test HassDict and custom HassKey types."""
from homeassistant.util.hass_dict import HassDict, HassEntryKey, HassKey
def test_key_comparison() -> None:
"""Test key comparison with itself and string keys."""
str_key = "custom-key"
key = HassKey[int](str_key)
other_key = HassKey[str]("other-key")
entry_key = HassEntryKey[int](str_key)
other_entry_key = HassEntryKey[str]("other-key")
assert key == str_key
assert key != other_key
assert key != 2
assert entry_key == str_key
assert entry_key != other_entry_key
assert entry_key != 2
# Only compare name attribute, HassKey(<name>) == HassEntryKey(<name>)
assert key == entry_key
def test_hass_dict_access() -> None:
"""Test keys with the same name all access the same value in HassDict."""
data = HassDict()
str_key = "custom-key"
key = HassKey[int](str_key)
other_key = HassKey[str]("other-key")
entry_key = HassEntryKey[int](str_key)
other_entry_key = HassEntryKey[str]("other-key")
data[str_key] = True
assert data.get(key) is True
assert data.get(other_key) is None
assert data.get(entry_key) is True # type: ignore[comparison-overlap]
assert data.get(other_entry_key) is None
data[key] = False
assert data[str_key] is False