Add improved typing for event fire and listen methods (#114906)

* Add EventType implementation

* Update integrations for EventType

* Change state_changed to EventType

* Fix tests

* Remove runtime impact

* Add tests

* Move to stub file

* Apply pre-commit to stub files

* Fix ruff PYI checks

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Marc Mueller 2024-04-08 01:28:24 +02:00 committed by GitHub
parent d007b175c5
commit a0e6fd6ec5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 182 additions and 62 deletions

View file

@ -6,7 +6,7 @@ repos:
args:
- --fix
- id: ruff-format
files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.py$
files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.(py|pyi)$
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
@ -63,7 +63,7 @@ repos:
language: script
types: [python]
require_serial: true
files: ^(homeassistant|pylint)/.+\.py$
files: ^(homeassistant|pylint)/.+\.(py|pyi)$
- id: pylint
name: pylint
entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Final
from typing import Any, Final
from homeassistant.const import (
EVENT_COMPONENT_LOADED,
@ -21,10 +21,11 @@ from homeassistant.helpers.area_registry import EVENT_AREA_REGISTRY_UPDATED
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.issue_registry import EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED
from homeassistant.util.event_type import EventType
# These are events that do not contain any sensitive data
# Except for state_changed, which is handled accordingly.
SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
SUBSCRIBE_ALLOWLIST: Final[set[EventType[Any] | str]] = {
EVENT_AREA_REGISTRY_UPDATED,
EVENT_COMPONENT_LOADED,
EVENT_CORE_CONFIG_UPDATE,

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import (
LOGBOOK_ENTRY_ICON,
@ -11,10 +12,11 @@ from homeassistant.components.logbook import (
)
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.event_type import EventType
from . import DOMAIN
EVENT_TO_NAME = {
EVENT_TO_NAME: dict[EventType[Any] | str, str] = {
EVENT_HOMEASSISTANT_STOP: "stopped",
EVENT_HOMEASSISTANT_START: "started",
}

View file

@ -31,6 +31,7 @@ from homeassistant.helpers.integration_platform import (
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util.event_type import EventType
from . import rest_api, websocket_api
from .const import ( # noqa: F401
@ -134,7 +135,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entities_filter = None
external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
] = {}
hass.data[DOMAIN] = LogbookConfig(external_events, filters, entities_filter)
websocket_api.async_setup(hass)

View file

@ -26,6 +26,7 @@ from homeassistant.core import (
)
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.util.event_type import EventType
from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN
from .models import LogbookConfig
@ -63,7 +64,7 @@ def _async_config_entries_for_ids(
def async_determine_event_types(
hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None
) -> tuple[str, ...]:
) -> tuple[EventType[Any] | str, ...]:
"""Reduce the event types based on the entity ids and device ids."""
logbook_config: LogbookConfig = hass.data[DOMAIN]
external_events = logbook_config.external_events
@ -81,7 +82,7 @@ def async_determine_event_types(
# to add them since we have historically included
# them when matching only on entities
#
intrested_event_types: set[str] = {
intrested_event_types: set[EventType[Any] | str] = {
external_event
for external_event, domain_call in external_events.items()
if domain_call[0] in interested_domains
@ -160,7 +161,7 @@ def async_subscribe_events(
hass: HomeAssistant,
subscriptions: list[CALLBACK_TYPE],
target: Callable[[Event[Any]], None],
event_types: tuple[str, ...],
event_types: tuple[EventType[Any] | str, ...],
entities_filter: Callable[[str], bool] | None,
entity_ids: list[str] | None,
device_ids: list[str] | None,

View file

@ -18,6 +18,7 @@ from homeassistant.components.recorder.models import (
)
from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED
from homeassistant.core import Context, Event, State, callback
from homeassistant.util.event_type import EventType
from homeassistant.util.json import json_loads
from homeassistant.util.ulid import ulid_to_bytes
@ -27,7 +28,8 @@ class LogbookConfig:
"""Configuration for the logbook integration."""
external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
]
sqlalchemy_filter: Filters | None = None
entity_filter: Callable[[str], bool] | None = None
@ -66,7 +68,7 @@ class LazyEventPartialState:
)
@cached_property
def event_type(self) -> str | None:
def event_type(self) -> EventType[Any] | str | None:
"""Return the event type."""
return self.row.event_type
@ -110,7 +112,7 @@ class EventAsRow:
icon: str | None = None
context_user_id_bin: bytes | None = None
context_parent_id_bin: bytes | None = None
event_type: str | None = None
event_type: EventType[Any] | str | None = None
state: str | None = None
context_only: None = None

View file

@ -38,6 +38,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import entity_registry as er
import homeassistant.util.dt as dt_util
from homeassistant.util.event_type import EventType
from .const import (
ATTR_MESSAGE,
@ -75,7 +76,8 @@ class LogbookRun:
context_lookup: dict[bytes | None, Row | EventAsRow | None]
external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
]
event_cache: EventCache
entity_name_cache: EntityNameCache
@ -90,7 +92,7 @@ class EventProcessor:
def __init__(
self,
hass: HomeAssistant,
event_types: tuple[str, ...],
event_types: tuple[EventType[Any] | str, ...],
entity_ids: list[str] | None = None,
device_ids: list[str] | None = None,
context_id: str | None = None,

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import (
LOGBOOK_ENTRY_ENTITY_ID,
@ -12,6 +13,7 @@ from homeassistant.components.logbook import (
)
from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.event_type import EventType
from .const import DOMAIN
@ -21,7 +23,7 @@ IOS_EVENT_ZONE_EXITED = "ios.zone_exited"
ATTR_ZONE = "zone"
ATTR_SOURCE_DEVICE_NAME = "sourceDeviceName"
ATTR_SOURCE_DEVICE_ID = "sourceDeviceID"
EVENT_TO_DESCRIPTION = {
EVENT_TO_DESCRIPTION: dict[EventType[Any] | str, str] = {
IOS_EVENT_ZONE_ENTERED: "entered zone",
IOS_EVENT_ZONE_EXITED: "exited zone",
}

View file

@ -25,6 +25,7 @@ from homeassistant.helpers.integration_platform import (
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util.event_type import EventType
from . import entity_registry, websocket_api
from .const import ( # noqa: F401
@ -146,7 +147,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass_config_path=hass.config.path(DEFAULT_DB_FILE)
)
exclude = conf[CONF_EXCLUDE]
exclude_event_types: set[str] = set(exclude.get(CONF_EVENT_TYPES, []))
exclude_event_types: set[EventType[Any] | str] = set(
exclude.get(CONF_EVENT_TYPES, [])
)
if EVENT_STATE_CHANGED in exclude_event_types:
_LOGGER.error("State change events cannot be excluded, use a filter instead")
exclude_event_types.remove(EVENT_STATE_CHANGED)

View file

@ -40,6 +40,7 @@ from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
from homeassistant.util.enum import try_parse_enum
from homeassistant.util.event_type import EventType
from . import migration, statistics
from .const import (
@ -173,7 +174,7 @@ class Recorder(threading.Thread):
db_max_retries: int,
db_retry_wait: int,
entity_filter: Callable[[str], bool],
exclude_event_types: set[str],
exclude_event_types: set[EventType[Any] | str],
) -> None:
"""Initialize the recorder."""
threading.Thread.__init__(self, name="Recorder")

View file

@ -2,9 +2,13 @@
from __future__ import annotations
from typing import Any
from homeassistant.util.event_type import EventType
def extract_event_type_ids(
event_type_to_event_type_id: dict[str, int | None],
event_type_to_event_type_id: dict[EventType[Any] | str, int | None],
) -> list[int]:
"""Extract event_type ids from event_type_to_event_type_id."""
return [

View file

@ -1,9 +1,11 @@
"""Managers for each table."""
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from lru import LRU
from homeassistant.util.event_type import EventType
if TYPE_CHECKING:
from ..core import Recorder
@ -13,7 +15,7 @@ _DataT = TypeVar("_DataT")
class BaseTableManager(Generic[_DataT]):
"""Base class for table managers."""
_id_map: "LRU[str, int]"
_id_map: "LRU[EventType[Any] | str, int]"
def __init__(self, recorder: "Recorder") -> None:
"""Initialize the table manager.
@ -24,7 +26,7 @@ class BaseTableManager(Generic[_DataT]):
"""
self.active = False
self.recorder = recorder
self._pending: dict[str, _DataT] = {}
self._pending: dict[EventType[Any] | str, _DataT] = {}
def get_from_cache(self, data: str) -> int | None:
"""Resolve data to the id without accessing the underlying database.
@ -34,7 +36,7 @@ class BaseTableManager(Generic[_DataT]):
"""
return self._id_map.get(data)
def get_pending(self, shared_data: str) -> _DataT | None:
def get_pending(self, shared_data: EventType[Any] | str) -> _DataT | None:
"""Get pending data that have not be assigned ids yet.
This call is not thread-safe and must be called from the

View file

@ -3,12 +3,13 @@
from __future__ import annotations
from collections.abc import Iterable
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast
from lru import LRU
from sqlalchemy.orm.session import Session
from homeassistant.core import Event
from homeassistant.util.event_type import EventType
from ..db_schema import EventTypes
from ..queries import find_event_type_ids
@ -29,7 +30,9 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
def __init__(self, recorder: Recorder) -> None:
"""Initialize the event type manager."""
super().__init__(recorder, CACHE_SIZE)
self._non_existent_event_types: LRU[str, None] = LRU(CACHE_SIZE)
self._non_existent_event_types: LRU[EventType[Any] | str, None] = LRU(
CACHE_SIZE
)
def load(self, events: list[Event], session: Session) -> None:
"""Load the event_type to event_type_ids mapping into memory.
@ -44,7 +47,10 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
)
def get(
self, event_type: str, session: Session, from_recorder: bool = False
self,
event_type: EventType[Any] | str,
session: Session,
from_recorder: bool = False,
) -> int | None:
"""Resolve event_type to the event_type_id.
@ -54,16 +60,19 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
return self.get_many((event_type,), session)[event_type]
def get_many(
self, event_types: Iterable[str], session: Session, from_recorder: bool = False
) -> dict[str, int | None]:
self,
event_types: Iterable[EventType[Any] | str],
session: Session,
from_recorder: bool = False,
) -> dict[EventType[Any] | str, int | None]:
"""Resolve event_types to event_type_ids.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
missing: list[str] = []
non_existent: list[str] = []
results: dict[EventType[Any] | str, int | None] = {}
missing: list[EventType[Any] | str] = []
non_existent: list[EventType[Any] | str] = []
for event_type in event_types:
if (event_type_id := self._id_map.get(event_type)) is None:
@ -123,7 +132,7 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
self.clear_non_existent(event_type)
self._pending.clear()
def clear_non_existent(self, event_type: str) -> None:
def clear_non_existent(self, event_type: EventType[Any] | str) -> None:
"""Clear a non-existent event type from the cache.
This call is not thread-safe and must be called from the

View file

@ -12,6 +12,7 @@ import threading
from typing import TYPE_CHECKING, Any
from homeassistant.helpers.typing import UndefinedType
from homeassistant.util.event_type import EventType
from . import entity_registry, purge, statistics
from .const import DOMAIN
@ -459,7 +460,7 @@ class EventIdMigrationTask(RecorderTask):
class RefreshEventTypesTask(RecorderTask):
"""An object to insert into the recorder queue to refresh event types."""
event_types: list[str]
event_types: list[EventType[Any] | str]
def run(self, instance: Recorder) -> None:
"""Refresh event types."""

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from enum import StrEnum
from functools import partial
from typing import Final
from typing import TYPE_CHECKING, Final
from .helpers.deprecation import (
DeprecatedConstant,
@ -13,8 +13,12 @@ from .helpers.deprecation import (
check_if_deprecated_constant,
dir_with_deprecated_constants,
)
from .util.event_type import EventType
from .util.signal_type import SignalType
if TYPE_CHECKING:
from .core import EventStateChangedData
APPLICATION_NAME: Final = "HomeAssistant"
MAJOR_VERSION: Final = 2024
MINOR_VERSION: Final = 5
@ -306,7 +310,7 @@ EVENT_LOGBOOK_ENTRY: Final = "logbook_entry"
EVENT_LOGGING_CHANGED: Final = "logging_changed"
EVENT_SERVICE_REGISTERED: Final = "service_registered"
EVENT_SERVICE_REMOVED: Final = "service_removed"
EVENT_STATE_CHANGED: Final = "state_changed"
EVENT_STATE_CHANGED: EventType[EventStateChangedData] = EventType("state_changed")
EVENT_STATE_REPORTED: Final = "state_reported"
EVENT_THEMES_UPDATED: Final = "themes_updated"
EVENT_PANELS_UPDATED: Final = "panels_updated"

View file

@ -102,6 +102,7 @@ from .util.async_ import (
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
from .util.event_type import EventType
from .util.executor import InterruptibleThreadPoolExecutor
from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict
@ -1216,7 +1217,7 @@ class Event(Generic[_DataT]):
def __init__(
self,
event_type: str,
event_type: EventType[_DataT] | str,
data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local,
time_fired_timestamp: float | None = None,
@ -1290,7 +1291,7 @@ class Event(Generic[_DataT]):
def _event_repr(
event_type: str, origin: EventOrigin, data: Mapping[str, Any] | None
event_type: EventType[_DataT] | str, origin: EventOrigin, data: _DataT | None
) -> str:
"""Return the representation."""
if data:
@ -1307,13 +1308,13 @@ _FilterableJobType = tuple[
@dataclass(slots=True)
class _OneTimeListener:
class _OneTimeListener(Generic[_DataT]):
hass: HomeAssistant
listener_job: HassJob[[Event], Coroutine[Any, Any, None] | None]
listener_job: HassJob[[Event[_DataT]], Coroutine[Any, Any, None] | None]
remove: CALLBACK_TYPE | None = None
@callback
def __call__(self, event: Event) -> None:
def __call__(self, event: Event[_DataT]) -> None:
"""Remove listener from event bus and then fire listener."""
if not self.remove:
# If the listener was already removed, we don't need to do anything
@ -1341,7 +1342,7 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners: dict[str, list[_FilterableJobType[Any]]] = {}
self._listeners: dict[EventType[Any] | str, list[_FilterableJobType[Any]]] = {}
self._match_all_listeners: list[_FilterableJobType[Any]] = []
self._listeners[MATCH_ALL] = self._match_all_listeners
self._hass = hass
@ -1356,7 +1357,7 @@ class EventBus:
self._debug = _LOGGER.isEnabledFor(logging.DEBUG)
@callback
def async_listeners(self) -> dict[str, int]:
def async_listeners(self) -> dict[EventType[Any] | str, int]:
"""Return dictionary with events and the number of listeners.
This method must be run in the event loop.
@ -1364,14 +1365,14 @@ class EventBus:
return {key: len(listeners) for key, listeners in self._listeners.items()}
@property
def listeners(self) -> dict[str, int]:
def listeners(self) -> dict[EventType[Any] | str, int]:
"""Return dictionary with events and the number of listeners."""
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
def fire(
self,
event_type: str,
event_data: Mapping[str, Any] | None = None,
event_type: EventType[_DataT] | str,
event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local,
context: Context | None = None,
) -> None:
@ -1383,8 +1384,8 @@ class EventBus:
@callback
def async_fire(
self,
event_type: str,
event_data: Mapping[str, Any] | None = None,
event_type: EventType[_DataT] | str,
event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local,
context: Context | None = None,
time_fired: float | None = None,
@ -1402,8 +1403,8 @@ class EventBus:
@callback
def _async_fire(
self,
event_type: str,
event_data: Mapping[str, Any] | None = None,
event_type: EventType[_DataT] | str,
event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local,
context: Context | None = None,
time_fired: float | None = None,
@ -1431,7 +1432,7 @@ class EventBus:
if not listeners:
return
event: Event | None = None
event: Event[_DataT] | None = None
for job, event_filter, run_immediately in listeners:
if event_filter is not None:
@ -1461,8 +1462,8 @@ class EventBus:
def listen(
self,
event_type: str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
event_type: EventType[_DataT] | str,
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
@ -1482,7 +1483,7 @@ class EventBus:
@callback
def async_listen(
self,
event_type: str,
event_type: EventType[_DataT] | str,
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
event_filter: Callable[[_DataT], bool] | None = None,
run_immediately: bool = True,
@ -1524,7 +1525,9 @@ class EventBus:
@callback
def _async_listen_filterable_job(
self, event_type: str, filterable_job: _FilterableJobType[Any]
self,
event_type: EventType[_DataT] | str,
filterable_job: _FilterableJobType[_DataT],
) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(filterable_job)
return functools.partial(
@ -1533,8 +1536,8 @@ class EventBus:
def listen_once(
self,
event_type: str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
event_type: EventType[_DataT] | str,
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
) -> CALLBACK_TYPE:
"""Listen once for event of a specific type.
@ -1556,8 +1559,8 @@ class EventBus:
@callback
def async_listen_once(
self,
event_type: str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
event_type: EventType[_DataT] | str,
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
run_immediately: bool = True,
) -> CALLBACK_TYPE:
"""Listen once for event of a specific type.
@ -1569,7 +1572,9 @@ class EventBus:
This method must be run in the event loop.
"""
one_time_listener = _OneTimeListener(self._hass, HassJob(listener))
one_time_listener: _OneTimeListener[_DataT] = _OneTimeListener(
self._hass, HassJob(listener)
)
remove = self._async_listen_filterable_job(
event_type,
(
@ -1587,7 +1592,9 @@ class EventBus:
@callback
def _async_remove_listener(
self, event_type: str, filterable_job: _FilterableJobType
self,
event_type: EventType[_DataT] | str,
filterable_job: _FilterableJobType[_DataT],
) -> None:
"""Remove a listener of a specific event_type.

View file

@ -4,7 +4,9 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from .util.event_type import EventType
if TYPE_CHECKING:
from .core import Context
@ -271,8 +273,12 @@ class ServiceNotFound(HomeAssistantError):
class MaxLengthExceeded(HomeAssistantError):
"""Raised when a property value has exceeded the max character length."""
def __init__(self, value: str, property_name: str, max_length: int) -> None:
def __init__(
self, value: EventType[Any] | str, property_name: str, max_length: int
) -> None:
"""Initialize error."""
if TYPE_CHECKING:
value = str(value)
super().__init__(
translation_domain="homeassistant",
translation_key="max_length_exceeded",

View file

@ -38,6 +38,7 @@ from homeassistant.exceptions import TemplateError
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.event_type import EventType
from .device_registry import (
EVENT_DEVICE_REGISTRY_UPDATED,
@ -90,7 +91,7 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
listeners_key: str
callbacks_key: str
event_type: str
event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[
[
HomeAssistant,

View file

@ -0,0 +1,20 @@
"""Implementation for EventType.
Custom for type checking. See stub file.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Generic
from typing_extensions import TypeVar
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
class EventType(str, Generic[_DataT]):
"""Custom type for Event.event_type.
At runtime this is a generic subclass of str.
"""

View file

@ -0,0 +1,25 @@
"""Stub file for event_type. Provide overload for type checking."""
# ruff: noqa: PYI021 # Allow docstrings
from collections.abc import Mapping
from typing import Any, Generic
from typing_extensions import TypeVar
__all__ = [
"EventType",
]
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
class EventType(Generic[_DataT]):
"""Custom type for Event.event_type. At runtime delegated to str.
For type checkers pretend to be its own separate class.
"""
def __init__(self, value: str, /) -> None: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, value: object, /) -> bool: ...
def __getitem__(self, index: int) -> str: ...

View file

@ -0,0 +1,25 @@
"""Test EventType implementation."""
from __future__ import annotations
import orjson
from homeassistant.util.event_type import EventType
def test_compatibility_with_str() -> None:
"""Test EventType. At runtime it should be (almost) fully compatible with str."""
event = EventType("Hello World")
assert event == "Hello World"
assert len(event) == 11
assert hash(event) == hash("Hello World")
d: dict[str | EventType, int] = {EventType("key"): 2}
assert d["key"] == 2
def test_json_dump() -> None:
"""Test EventType json dump with orjson."""
event = EventType("state_changed")
assert orjson.dumps({"event_type": event}) == b'{"event_type":"state_changed"}'