Add improved typing for event fire and listen methods (#114906)

* Add EventType implementation

* Update integrations for EventType

* Change state_changed to EventType

* Fix tests

* Remove runtime impact

* Add tests

* Move to stub file

* Apply pre-commit to stub files

* Fix ruff PYI checks

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Marc Mueller 2024-04-08 01:28:24 +02:00 committed by GitHub
parent d007b175c5
commit a0e6fd6ec5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 182 additions and 62 deletions

View file

@ -6,7 +6,7 @@ repos:
args: args:
- --fix - --fix
- id: ruff-format - id: ruff-format
files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.py$ files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.(py|pyi)$
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.2.6 rev: v2.2.6
hooks: hooks:
@ -63,7 +63,7 @@ repos:
language: script language: script
types: [python] types: [python]
require_serial: true require_serial: true
files: ^(homeassistant|pylint)/.+\.py$ files: ^(homeassistant|pylint)/.+\.(py|pyi)$
- id: pylint - id: pylint
name: pylint name: pylint
entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Final from typing import Any, Final
from homeassistant.const import ( from homeassistant.const import (
EVENT_COMPONENT_LOADED, EVENT_COMPONENT_LOADED,
@ -21,10 +21,11 @@ from homeassistant.helpers.area_registry import EVENT_AREA_REGISTRY_UPDATED
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.issue_registry import EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED from homeassistant.helpers.issue_registry import EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED
from homeassistant.util.event_type import EventType
# These are events that do not contain any sensitive data # These are events that do not contain any sensitive data
# Except for state_changed, which is handled accordingly. # Except for state_changed, which is handled accordingly.
SUBSCRIBE_ALLOWLIST: Final[set[str]] = { SUBSCRIBE_ALLOWLIST: Final[set[EventType[Any] | str]] = {
EVENT_AREA_REGISTRY_UPDATED, EVENT_AREA_REGISTRY_UPDATED,
EVENT_COMPONENT_LOADED, EVENT_COMPONENT_LOADED,
EVENT_CORE_CONFIG_UPDATE, EVENT_CORE_CONFIG_UPDATE,

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import ( from homeassistant.components.logbook import (
LOGBOOK_ENTRY_ICON, LOGBOOK_ENTRY_ICON,
@ -11,10 +12,11 @@ from homeassistant.components.logbook import (
) )
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.event_type import EventType
from . import DOMAIN from . import DOMAIN
EVENT_TO_NAME = { EVENT_TO_NAME: dict[EventType[Any] | str, str] = {
EVENT_HOMEASSISTANT_STOP: "stopped", EVENT_HOMEASSISTANT_STOP: "stopped",
EVENT_HOMEASSISTANT_START: "started", EVENT_HOMEASSISTANT_START: "started",
} }

View file

@ -31,6 +31,7 @@ from homeassistant.helpers.integration_platform import (
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.event_type import EventType
from . import rest_api, websocket_api from . import rest_api, websocket_api
from .const import ( # noqa: F401 from .const import ( # noqa: F401
@ -134,7 +135,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entities_filter = None entities_filter = None
external_events: dict[ external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
] = {} ] = {}
hass.data[DOMAIN] = LogbookConfig(external_events, filters, entities_filter) hass.data[DOMAIN] = LogbookConfig(external_events, filters, entities_filter)
websocket_api.async_setup(hass) websocket_api.async_setup(hass)

View file

@ -26,6 +26,7 @@ from homeassistant.core import (
) )
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.util.event_type import EventType
from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN
from .models import LogbookConfig from .models import LogbookConfig
@ -63,7 +64,7 @@ def _async_config_entries_for_ids(
def async_determine_event_types( def async_determine_event_types(
hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None
) -> tuple[str, ...]: ) -> tuple[EventType[Any] | str, ...]:
"""Reduce the event types based on the entity ids and device ids.""" """Reduce the event types based on the entity ids and device ids."""
logbook_config: LogbookConfig = hass.data[DOMAIN] logbook_config: LogbookConfig = hass.data[DOMAIN]
external_events = logbook_config.external_events external_events = logbook_config.external_events
@ -81,7 +82,7 @@ def async_determine_event_types(
# to add them since we have historically included # to add them since we have historically included
# them when matching only on entities # them when matching only on entities
# #
intrested_event_types: set[str] = { intrested_event_types: set[EventType[Any] | str] = {
external_event external_event
for external_event, domain_call in external_events.items() for external_event, domain_call in external_events.items()
if domain_call[0] in interested_domains if domain_call[0] in interested_domains
@ -160,7 +161,7 @@ def async_subscribe_events(
hass: HomeAssistant, hass: HomeAssistant,
subscriptions: list[CALLBACK_TYPE], subscriptions: list[CALLBACK_TYPE],
target: Callable[[Event[Any]], None], target: Callable[[Event[Any]], None],
event_types: tuple[str, ...], event_types: tuple[EventType[Any] | str, ...],
entities_filter: Callable[[str], bool] | None, entities_filter: Callable[[str], bool] | None,
entity_ids: list[str] | None, entity_ids: list[str] | None,
device_ids: list[str] | None, device_ids: list[str] | None,

View file

@ -18,6 +18,7 @@ from homeassistant.components.recorder.models import (
) )
from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED
from homeassistant.core import Context, Event, State, callback from homeassistant.core import Context, Event, State, callback
from homeassistant.util.event_type import EventType
from homeassistant.util.json import json_loads from homeassistant.util.json import json_loads
from homeassistant.util.ulid import ulid_to_bytes from homeassistant.util.ulid import ulid_to_bytes
@ -27,7 +28,8 @@ class LogbookConfig:
"""Configuration for the logbook integration.""" """Configuration for the logbook integration."""
external_events: dict[ external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
] ]
sqlalchemy_filter: Filters | None = None sqlalchemy_filter: Filters | None = None
entity_filter: Callable[[str], bool] | None = None entity_filter: Callable[[str], bool] | None = None
@ -66,7 +68,7 @@ class LazyEventPartialState:
) )
@cached_property @cached_property
def event_type(self) -> str | None: def event_type(self) -> EventType[Any] | str | None:
"""Return the event type.""" """Return the event type."""
return self.row.event_type return self.row.event_type
@ -110,7 +112,7 @@ class EventAsRow:
icon: str | None = None icon: str | None = None
context_user_id_bin: bytes | None = None context_user_id_bin: bytes | None = None
context_parent_id_bin: bytes | None = None context_parent_id_bin: bytes | None = None
event_type: str | None = None event_type: EventType[Any] | str | None = None
state: str | None = None state: str | None = None
context_only: None = None context_only: None = None

View file

@ -38,6 +38,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, split_entity_id from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.event_type import EventType
from .const import ( from .const import (
ATTR_MESSAGE, ATTR_MESSAGE,
@ -75,7 +76,8 @@ class LogbookRun:
context_lookup: dict[bytes | None, Row | EventAsRow | None] context_lookup: dict[bytes | None, Row | EventAsRow | None]
external_events: dict[ external_events: dict[
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] EventType[Any] | str,
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
] ]
event_cache: EventCache event_cache: EventCache
entity_name_cache: EntityNameCache entity_name_cache: EntityNameCache
@ -90,7 +92,7 @@ class EventProcessor:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
event_types: tuple[str, ...], event_types: tuple[EventType[Any] | str, ...],
entity_ids: list[str] | None = None, entity_ids: list[str] | None = None,
device_ids: list[str] | None = None, device_ids: list[str] | None = None,
context_id: str | None = None, context_id: str | None = None,

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import ( from homeassistant.components.logbook import (
LOGBOOK_ENTRY_ENTITY_ID, LOGBOOK_ENTRY_ENTITY_ID,
@ -12,6 +13,7 @@ from homeassistant.components.logbook import (
) )
from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.event_type import EventType
from .const import DOMAIN from .const import DOMAIN
@ -21,7 +23,7 @@ IOS_EVENT_ZONE_EXITED = "ios.zone_exited"
ATTR_ZONE = "zone" ATTR_ZONE = "zone"
ATTR_SOURCE_DEVICE_NAME = "sourceDeviceName" ATTR_SOURCE_DEVICE_NAME = "sourceDeviceName"
ATTR_SOURCE_DEVICE_ID = "sourceDeviceID" ATTR_SOURCE_DEVICE_ID = "sourceDeviceID"
EVENT_TO_DESCRIPTION = { EVENT_TO_DESCRIPTION: dict[EventType[Any] | str, str] = {
IOS_EVENT_ZONE_ENTERED: "entered zone", IOS_EVENT_ZONE_ENTERED: "entered zone",
IOS_EVENT_ZONE_EXITED: "exited zone", IOS_EVENT_ZONE_EXITED: "exited zone",
} }

View file

@ -25,6 +25,7 @@ from homeassistant.helpers.integration_platform import (
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.event_type import EventType
from . import entity_registry, websocket_api from . import entity_registry, websocket_api
from .const import ( # noqa: F401 from .const import ( # noqa: F401
@ -146,7 +147,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass_config_path=hass.config.path(DEFAULT_DB_FILE) hass_config_path=hass.config.path(DEFAULT_DB_FILE)
) )
exclude = conf[CONF_EXCLUDE] exclude = conf[CONF_EXCLUDE]
exclude_event_types: set[str] = set(exclude.get(CONF_EVENT_TYPES, [])) exclude_event_types: set[EventType[Any] | str] = set(
exclude.get(CONF_EVENT_TYPES, [])
)
if EVENT_STATE_CHANGED in exclude_event_types: if EVENT_STATE_CHANGED in exclude_event_types:
_LOGGER.error("State change events cannot be excluded, use a filter instead") _LOGGER.error("State change events cannot be excluded, use a filter instead")
exclude_event_types.remove(EVENT_STATE_CHANGED) exclude_event_types.remove(EVENT_STATE_CHANGED)

View file

@ -40,6 +40,7 @@ from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import UNDEFINED, UndefinedType from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.enum import try_parse_enum from homeassistant.util.enum import try_parse_enum
from homeassistant.util.event_type import EventType
from . import migration, statistics from . import migration, statistics
from .const import ( from .const import (
@ -173,7 +174,7 @@ class Recorder(threading.Thread):
db_max_retries: int, db_max_retries: int,
db_retry_wait: int, db_retry_wait: int,
entity_filter: Callable[[str], bool], entity_filter: Callable[[str], bool],
exclude_event_types: set[str], exclude_event_types: set[EventType[Any] | str],
) -> None: ) -> None:
"""Initialize the recorder.""" """Initialize the recorder."""
threading.Thread.__init__(self, name="Recorder") threading.Thread.__init__(self, name="Recorder")

View file

@ -2,9 +2,13 @@
from __future__ import annotations from __future__ import annotations
from typing import Any
from homeassistant.util.event_type import EventType
def extract_event_type_ids( def extract_event_type_ids(
event_type_to_event_type_id: dict[str, int | None], event_type_to_event_type_id: dict[EventType[Any] | str, int | None],
) -> list[int]: ) -> list[int]:
"""Extract event_type ids from event_type_to_event_type_id.""" """Extract event_type ids from event_type_to_event_type_id."""
return [ return [

View file

@ -1,9 +1,11 @@
"""Managers for each table.""" """Managers for each table."""
from typing import TYPE_CHECKING, Generic, TypeVar from typing import TYPE_CHECKING, Any, Generic, TypeVar
from lru import LRU from lru import LRU
from homeassistant.util.event_type import EventType
if TYPE_CHECKING: if TYPE_CHECKING:
from ..core import Recorder from ..core import Recorder
@ -13,7 +15,7 @@ _DataT = TypeVar("_DataT")
class BaseTableManager(Generic[_DataT]): class BaseTableManager(Generic[_DataT]):
"""Base class for table managers.""" """Base class for table managers."""
_id_map: "LRU[str, int]" _id_map: "LRU[EventType[Any] | str, int]"
def __init__(self, recorder: "Recorder") -> None: def __init__(self, recorder: "Recorder") -> None:
"""Initialize the table manager. """Initialize the table manager.
@ -24,7 +26,7 @@ class BaseTableManager(Generic[_DataT]):
""" """
self.active = False self.active = False
self.recorder = recorder self.recorder = recorder
self._pending: dict[str, _DataT] = {} self._pending: dict[EventType[Any] | str, _DataT] = {}
def get_from_cache(self, data: str) -> int | None: def get_from_cache(self, data: str) -> int | None:
"""Resolve data to the id without accessing the underlying database. """Resolve data to the id without accessing the underlying database.
@ -34,7 +36,7 @@ class BaseTableManager(Generic[_DataT]):
""" """
return self._id_map.get(data) return self._id_map.get(data)
def get_pending(self, shared_data: str) -> _DataT | None: def get_pending(self, shared_data: EventType[Any] | str) -> _DataT | None:
"""Get pending data that have not be assigned ids yet. """Get pending data that have not be assigned ids yet.
This call is not thread-safe and must be called from the This call is not thread-safe and must be called from the

View file

@ -3,12 +3,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, Any, cast
from lru import LRU from lru import LRU
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event from homeassistant.core import Event
from homeassistant.util.event_type import EventType
from ..db_schema import EventTypes from ..db_schema import EventTypes
from ..queries import find_event_type_ids from ..queries import find_event_type_ids
@ -29,7 +30,9 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
def __init__(self, recorder: Recorder) -> None: def __init__(self, recorder: Recorder) -> None:
"""Initialize the event type manager.""" """Initialize the event type manager."""
super().__init__(recorder, CACHE_SIZE) super().__init__(recorder, CACHE_SIZE)
self._non_existent_event_types: LRU[str, None] = LRU(CACHE_SIZE) self._non_existent_event_types: LRU[EventType[Any] | str, None] = LRU(
CACHE_SIZE
)
def load(self, events: list[Event], session: Session) -> None: def load(self, events: list[Event], session: Session) -> None:
"""Load the event_type to event_type_ids mapping into memory. """Load the event_type to event_type_ids mapping into memory.
@ -44,7 +47,10 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
) )
def get( def get(
self, event_type: str, session: Session, from_recorder: bool = False self,
event_type: EventType[Any] | str,
session: Session,
from_recorder: bool = False,
) -> int | None: ) -> int | None:
"""Resolve event_type to the event_type_id. """Resolve event_type to the event_type_id.
@ -54,16 +60,19 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
return self.get_many((event_type,), session)[event_type] return self.get_many((event_type,), session)[event_type]
def get_many( def get_many(
self, event_types: Iterable[str], session: Session, from_recorder: bool = False self,
) -> dict[str, int | None]: event_types: Iterable[EventType[Any] | str],
session: Session,
from_recorder: bool = False,
) -> dict[EventType[Any] | str, int | None]:
"""Resolve event_types to event_type_ids. """Resolve event_types to event_type_ids.
This call is not thread-safe and must be called from the This call is not thread-safe and must be called from the
recorder thread. recorder thread.
""" """
results: dict[str, int | None] = {} results: dict[EventType[Any] | str, int | None] = {}
missing: list[str] = [] missing: list[EventType[Any] | str] = []
non_existent: list[str] = [] non_existent: list[EventType[Any] | str] = []
for event_type in event_types: for event_type in event_types:
if (event_type_id := self._id_map.get(event_type)) is None: if (event_type_id := self._id_map.get(event_type)) is None:
@ -123,7 +132,7 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
self.clear_non_existent(event_type) self.clear_non_existent(event_type)
self._pending.clear() self._pending.clear()
def clear_non_existent(self, event_type: str) -> None: def clear_non_existent(self, event_type: EventType[Any] | str) -> None:
"""Clear a non-existent event type from the cache. """Clear a non-existent event type from the cache.
This call is not thread-safe and must be called from the This call is not thread-safe and must be called from the

View file

@ -12,6 +12,7 @@ import threading
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from homeassistant.helpers.typing import UndefinedType from homeassistant.helpers.typing import UndefinedType
from homeassistant.util.event_type import EventType
from . import entity_registry, purge, statistics from . import entity_registry, purge, statistics
from .const import DOMAIN from .const import DOMAIN
@ -459,7 +460,7 @@ class EventIdMigrationTask(RecorderTask):
class RefreshEventTypesTask(RecorderTask): class RefreshEventTypesTask(RecorderTask):
"""An object to insert into the recorder queue to refresh event types.""" """An object to insert into the recorder queue to refresh event types."""
event_types: list[str] event_types: list[EventType[Any] | str]
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Refresh event types.""" """Refresh event types."""

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from enum import StrEnum from enum import StrEnum
from functools import partial from functools import partial
from typing import Final from typing import TYPE_CHECKING, Final
from .helpers.deprecation import ( from .helpers.deprecation import (
DeprecatedConstant, DeprecatedConstant,
@ -13,8 +13,12 @@ from .helpers.deprecation import (
check_if_deprecated_constant, check_if_deprecated_constant,
dir_with_deprecated_constants, dir_with_deprecated_constants,
) )
from .util.event_type import EventType
from .util.signal_type import SignalType from .util.signal_type import SignalType
if TYPE_CHECKING:
from .core import EventStateChangedData
APPLICATION_NAME: Final = "HomeAssistant" APPLICATION_NAME: Final = "HomeAssistant"
MAJOR_VERSION: Final = 2024 MAJOR_VERSION: Final = 2024
MINOR_VERSION: Final = 5 MINOR_VERSION: Final = 5
@ -306,7 +310,7 @@ EVENT_LOGBOOK_ENTRY: Final = "logbook_entry"
EVENT_LOGGING_CHANGED: Final = "logging_changed" EVENT_LOGGING_CHANGED: Final = "logging_changed"
EVENT_SERVICE_REGISTERED: Final = "service_registered" EVENT_SERVICE_REGISTERED: Final = "service_registered"
EVENT_SERVICE_REMOVED: Final = "service_removed" EVENT_SERVICE_REMOVED: Final = "service_removed"
EVENT_STATE_CHANGED: Final = "state_changed" EVENT_STATE_CHANGED: EventType[EventStateChangedData] = EventType("state_changed")
EVENT_STATE_REPORTED: Final = "state_reported" EVENT_STATE_REPORTED: Final = "state_reported"
EVENT_THEMES_UPDATED: Final = "themes_updated" EVENT_THEMES_UPDATED: Final = "themes_updated"
EVENT_PANELS_UPDATED: Final = "panels_updated" EVENT_PANELS_UPDATED: Final = "panels_updated"

View file

@ -102,6 +102,7 @@ from .util.async_ import (
run_callback_threadsafe, run_callback_threadsafe,
shutdown_run_callback_threadsafe, shutdown_run_callback_threadsafe,
) )
from .util.event_type import EventType
from .util.executor import InterruptibleThreadPoolExecutor from .util.executor import InterruptibleThreadPoolExecutor
from .util.json import JsonObjectType from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict from .util.read_only_dict import ReadOnlyDict
@ -1216,7 +1217,7 @@ class Event(Generic[_DataT]):
def __init__( def __init__(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
data: _DataT | None = None, data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
time_fired_timestamp: float | None = None, time_fired_timestamp: float | None = None,
@ -1290,7 +1291,7 @@ class Event(Generic[_DataT]):
def _event_repr( def _event_repr(
event_type: str, origin: EventOrigin, data: Mapping[str, Any] | None event_type: EventType[_DataT] | str, origin: EventOrigin, data: _DataT | None
) -> str: ) -> str:
"""Return the representation.""" """Return the representation."""
if data: if data:
@ -1307,13 +1308,13 @@ _FilterableJobType = tuple[
@dataclass(slots=True) @dataclass(slots=True)
class _OneTimeListener: class _OneTimeListener(Generic[_DataT]):
hass: HomeAssistant hass: HomeAssistant
listener_job: HassJob[[Event], Coroutine[Any, Any, None] | None] listener_job: HassJob[[Event[_DataT]], Coroutine[Any, Any, None] | None]
remove: CALLBACK_TYPE | None = None remove: CALLBACK_TYPE | None = None
@callback @callback
def __call__(self, event: Event) -> None: def __call__(self, event: Event[_DataT]) -> None:
"""Remove listener from event bus and then fire listener.""" """Remove listener from event bus and then fire listener."""
if not self.remove: if not self.remove:
# If the listener was already removed, we don't need to do anything # If the listener was already removed, we don't need to do anything
@ -1341,7 +1342,7 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus.""" """Initialize a new event bus."""
self._listeners: dict[str, list[_FilterableJobType[Any]]] = {} self._listeners: dict[EventType[Any] | str, list[_FilterableJobType[Any]]] = {}
self._match_all_listeners: list[_FilterableJobType[Any]] = [] self._match_all_listeners: list[_FilterableJobType[Any]] = []
self._listeners[MATCH_ALL] = self._match_all_listeners self._listeners[MATCH_ALL] = self._match_all_listeners
self._hass = hass self._hass = hass
@ -1356,7 +1357,7 @@ class EventBus:
self._debug = _LOGGER.isEnabledFor(logging.DEBUG) self._debug = _LOGGER.isEnabledFor(logging.DEBUG)
@callback @callback
def async_listeners(self) -> dict[str, int]: def async_listeners(self) -> dict[EventType[Any] | str, int]:
"""Return dictionary with events and the number of listeners. """Return dictionary with events and the number of listeners.
This method must be run in the event loop. This method must be run in the event loop.
@ -1364,14 +1365,14 @@ class EventBus:
return {key: len(listeners) for key, listeners in self._listeners.items()} return {key: len(listeners) for key, listeners in self._listeners.items()}
@property @property
def listeners(self) -> dict[str, int]: def listeners(self) -> dict[EventType[Any] | str, int]:
"""Return dictionary with events and the number of listeners.""" """Return dictionary with events and the number of listeners."""
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result() return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
def fire( def fire(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
event_data: Mapping[str, Any] | None = None, event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
context: Context | None = None, context: Context | None = None,
) -> None: ) -> None:
@ -1383,8 +1384,8 @@ class EventBus:
@callback @callback
def async_fire( def async_fire(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
event_data: Mapping[str, Any] | None = None, event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
context: Context | None = None, context: Context | None = None,
time_fired: float | None = None, time_fired: float | None = None,
@ -1402,8 +1403,8 @@ class EventBus:
@callback @callback
def _async_fire( def _async_fire(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
event_data: Mapping[str, Any] | None = None, event_data: _DataT | None = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
context: Context | None = None, context: Context | None = None,
time_fired: float | None = None, time_fired: float | None = None,
@ -1431,7 +1432,7 @@ class EventBus:
if not listeners: if not listeners:
return return
event: Event | None = None event: Event[_DataT] | None = None
for job, event_filter, run_immediately in listeners: for job, event_filter, run_immediately in listeners:
if event_filter is not None: if event_filter is not None:
@ -1461,8 +1462,8 @@ class EventBus:
def listen( def listen(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
@ -1482,7 +1483,7 @@ class EventBus:
@callback @callback
def async_listen( def async_listen(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None], listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
event_filter: Callable[[_DataT], bool] | None = None, event_filter: Callable[[_DataT], bool] | None = None,
run_immediately: bool = True, run_immediately: bool = True,
@ -1524,7 +1525,9 @@ class EventBus:
@callback @callback
def _async_listen_filterable_job( def _async_listen_filterable_job(
self, event_type: str, filterable_job: _FilterableJobType[Any] self,
event_type: EventType[_DataT] | str,
filterable_job: _FilterableJobType[_DataT],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(filterable_job) self._listeners.setdefault(event_type, []).append(filterable_job)
return functools.partial( return functools.partial(
@ -1533,8 +1536,8 @@ class EventBus:
def listen_once( def listen_once(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen once for event of a specific type. """Listen once for event of a specific type.
@ -1556,8 +1559,8 @@ class EventBus:
@callback @callback
def async_listen_once( def async_listen_once(
self, self,
event_type: str, event_type: EventType[_DataT] | str,
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
run_immediately: bool = True, run_immediately: bool = True,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen once for event of a specific type. """Listen once for event of a specific type.
@ -1569,7 +1572,9 @@ class EventBus:
This method must be run in the event loop. This method must be run in the event loop.
""" """
one_time_listener = _OneTimeListener(self._hass, HassJob(listener)) one_time_listener: _OneTimeListener[_DataT] = _OneTimeListener(
self._hass, HassJob(listener)
)
remove = self._async_listen_filterable_job( remove = self._async_listen_filterable_job(
event_type, event_type,
( (
@ -1587,7 +1592,9 @@ class EventBus:
@callback @callback
def _async_remove_listener( def _async_remove_listener(
self, event_type: str, filterable_job: _FilterableJobType self,
event_type: EventType[_DataT] | str,
filterable_job: _FilterableJobType[_DataT],
) -> None: ) -> None:
"""Remove a listener of a specific event_type. """Remove a listener of a specific event_type.

View file

@ -4,7 +4,9 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from .util.event_type import EventType
if TYPE_CHECKING: if TYPE_CHECKING:
from .core import Context from .core import Context
@ -271,8 +273,12 @@ class ServiceNotFound(HomeAssistantError):
class MaxLengthExceeded(HomeAssistantError): class MaxLengthExceeded(HomeAssistantError):
"""Raised when a property value has exceeded the max character length.""" """Raised when a property value has exceeded the max character length."""
def __init__(self, value: str, property_name: str, max_length: int) -> None: def __init__(
self, value: EventType[Any] | str, property_name: str, max_length: int
) -> None:
"""Initialize error.""" """Initialize error."""
if TYPE_CHECKING:
value = str(value)
super().__init__( super().__init__(
translation_domain="homeassistant", translation_domain="homeassistant",
translation_key="max_length_exceeded", translation_key="max_length_exceeded",

View file

@ -38,6 +38,7 @@ from homeassistant.exceptions import TemplateError
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.event_type import EventType
from .device_registry import ( from .device_registry import (
EVENT_DEVICE_REGISTRY_UPDATED, EVENT_DEVICE_REGISTRY_UPDATED,
@ -90,7 +91,7 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
listeners_key: str listeners_key: str
callbacks_key: str callbacks_key: str
event_type: str event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[ dispatcher_callable: Callable[
[ [
HomeAssistant, HomeAssistant,

View file

@ -0,0 +1,20 @@
"""Implementation for EventType.
Custom for type checking. See stub file.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Generic
from typing_extensions import TypeVar
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
class EventType(str, Generic[_DataT]):
"""Custom type for Event.event_type.
At runtime this is a generic subclass of str.
"""

View file

@ -0,0 +1,25 @@
"""Stub file for event_type. Provide overload for type checking."""
# ruff: noqa: PYI021 # Allow docstrings
from collections.abc import Mapping
from typing import Any, Generic
from typing_extensions import TypeVar
__all__ = [
"EventType",
]
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
class EventType(Generic[_DataT]):
"""Custom type for Event.event_type. At runtime delegated to str.
For type checkers pretend to be its own separate class.
"""
def __init__(self, value: str, /) -> None: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, value: object, /) -> bool: ...
def __getitem__(self, index: int) -> str: ...

View file

@ -0,0 +1,25 @@
"""Test EventType implementation."""
from __future__ import annotations
import orjson
from homeassistant.util.event_type import EventType
def test_compatibility_with_str() -> None:
"""Test EventType. At runtime it should be (almost) fully compatible with str."""
event = EventType("Hello World")
assert event == "Hello World"
assert len(event) == 11
assert hash(event) == hash("Hello World")
d: dict[str | EventType, int] = {EventType("key"): 2}
assert d["key"] == 2
def test_json_dump() -> None:
"""Test EventType json dump with orjson."""
event = EventType("state_changed")
assert orjson.dumps({"event_type": event}) == b'{"event_type":"state_changed"}'