Add improved typing for event fire and listen methods (#114906)
* Add EventType implementation * Update integrations for EventType * Change state_changed to EventType * Fix tests * Remove runtime impact * Add tests * Move to stub file * Apply pre-commit to stub files * Fix ruff PYI checks --------- Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
d007b175c5
commit
a0e6fd6ec5
21 changed files with 182 additions and 62 deletions
|
@ -6,7 +6,7 @@ repos:
|
|||
args:
|
||||
- --fix
|
||||
- id: ruff-format
|
||||
files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.py$
|
||||
files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.(py|pyi)$
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.6
|
||||
hooks:
|
||||
|
@ -63,7 +63,7 @@ repos:
|
|||
language: script
|
||||
types: [python]
|
||||
require_serial: true
|
||||
files: ^(homeassistant|pylint)/.+\.py$
|
||||
files: ^(homeassistant|pylint)/.+\.(py|pyi)$
|
||||
- id: pylint
|
||||
name: pylint
|
||||
entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
from typing import Any, Final
|
||||
|
||||
from homeassistant.const import (
|
||||
EVENT_COMPONENT_LOADED,
|
||||
|
@ -21,10 +21,11 @@ from homeassistant.helpers.area_registry import EVENT_AREA_REGISTRY_UPDATED
|
|||
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.issue_registry import EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
# These are events that do not contain any sensitive data
|
||||
# Except for state_changed, which is handled accordingly.
|
||||
SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
|
||||
SUBSCRIBE_ALLOWLIST: Final[set[EventType[Any] | str]] = {
|
||||
EVENT_AREA_REGISTRY_UPDATED,
|
||||
EVENT_COMPONENT_LOADED,
|
||||
EVENT_CORE_CONFIG_UPDATE,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.logbook import (
|
||||
LOGBOOK_ENTRY_ICON,
|
||||
|
@ -11,10 +12,11 @@ from homeassistant.components.logbook import (
|
|||
)
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from . import DOMAIN
|
||||
|
||||
EVENT_TO_NAME = {
|
||||
EVENT_TO_NAME: dict[EventType[Any] | str, str] = {
|
||||
EVENT_HOMEASSISTANT_STOP: "stopped",
|
||||
EVENT_HOMEASSISTANT_START: "started",
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.helpers.integration_platform import (
|
|||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from . import rest_api, websocket_api
|
||||
from .const import ( # noqa: F401
|
||||
|
@ -134,7 +135,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
entities_filter = None
|
||||
|
||||
external_events: dict[
|
||||
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
|
||||
EventType[Any] | str,
|
||||
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
|
||||
] = {}
|
||||
hass.data[DOMAIN] = LogbookConfig(external_events, filters, entities_filter)
|
||||
websocket_api.async_setup(hass)
|
||||
|
|
|
@ -26,6 +26,7 @@ from homeassistant.core import (
|
|||
)
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.event import async_track_state_change_event
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN
|
||||
from .models import LogbookConfig
|
||||
|
@ -63,7 +64,7 @@ def _async_config_entries_for_ids(
|
|||
|
||||
def async_determine_event_types(
|
||||
hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None
|
||||
) -> tuple[str, ...]:
|
||||
) -> tuple[EventType[Any] | str, ...]:
|
||||
"""Reduce the event types based on the entity ids and device ids."""
|
||||
logbook_config: LogbookConfig = hass.data[DOMAIN]
|
||||
external_events = logbook_config.external_events
|
||||
|
@ -81,7 +82,7 @@ def async_determine_event_types(
|
|||
# to add them since we have historically included
|
||||
# them when matching only on entities
|
||||
#
|
||||
intrested_event_types: set[str] = {
|
||||
intrested_event_types: set[EventType[Any] | str] = {
|
||||
external_event
|
||||
for external_event, domain_call in external_events.items()
|
||||
if domain_call[0] in interested_domains
|
||||
|
@ -160,7 +161,7 @@ def async_subscribe_events(
|
|||
hass: HomeAssistant,
|
||||
subscriptions: list[CALLBACK_TYPE],
|
||||
target: Callable[[Event[Any]], None],
|
||||
event_types: tuple[str, ...],
|
||||
event_types: tuple[EventType[Any] | str, ...],
|
||||
entities_filter: Callable[[str], bool] | None,
|
||||
entity_ids: list[str] | None,
|
||||
device_ids: list[str] | None,
|
||||
|
|
|
@ -18,6 +18,7 @@ from homeassistant.components.recorder.models import (
|
|||
)
|
||||
from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED
|
||||
from homeassistant.core import Context, Event, State, callback
|
||||
from homeassistant.util.event_type import EventType
|
||||
from homeassistant.util.json import json_loads
|
||||
from homeassistant.util.ulid import ulid_to_bytes
|
||||
|
||||
|
@ -27,7 +28,8 @@ class LogbookConfig:
|
|||
"""Configuration for the logbook integration."""
|
||||
|
||||
external_events: dict[
|
||||
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
|
||||
EventType[Any] | str,
|
||||
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
|
||||
]
|
||||
sqlalchemy_filter: Filters | None = None
|
||||
entity_filter: Callable[[str], bool] | None = None
|
||||
|
@ -66,7 +68,7 @@ class LazyEventPartialState:
|
|||
)
|
||||
|
||||
@cached_property
|
||||
def event_type(self) -> str | None:
|
||||
def event_type(self) -> EventType[Any] | str | None:
|
||||
"""Return the event type."""
|
||||
return self.row.event_type
|
||||
|
||||
|
@ -110,7 +112,7 @@ class EventAsRow:
|
|||
icon: str | None = None
|
||||
context_user_id_bin: bytes | None = None
|
||||
context_parent_id_bin: bytes | None = None
|
||||
event_type: str | None = None
|
||||
event_type: EventType[Any] | str | None = None
|
||||
state: str | None = None
|
||||
context_only: None = None
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from .const import (
|
||||
ATTR_MESSAGE,
|
||||
|
@ -75,7 +76,8 @@ class LogbookRun:
|
|||
|
||||
context_lookup: dict[bytes | None, Row | EventAsRow | None]
|
||||
external_events: dict[
|
||||
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
|
||||
EventType[Any] | str,
|
||||
tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]],
|
||||
]
|
||||
event_cache: EventCache
|
||||
entity_name_cache: EntityNameCache
|
||||
|
@ -90,7 +92,7 @@ class EventProcessor:
|
|||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
event_types: tuple[str, ...],
|
||||
event_types: tuple[EventType[Any] | str, ...],
|
||||
entity_ids: list[str] | None = None,
|
||||
device_ids: list[str] | None = None,
|
||||
context_id: str | None = None,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.logbook import (
|
||||
LOGBOOK_ENTRY_ENTITY_ID,
|
||||
|
@ -12,6 +13,7 @@ from homeassistant.components.logbook import (
|
|||
)
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
@ -21,7 +23,7 @@ IOS_EVENT_ZONE_EXITED = "ios.zone_exited"
|
|||
ATTR_ZONE = "zone"
|
||||
ATTR_SOURCE_DEVICE_NAME = "sourceDeviceName"
|
||||
ATTR_SOURCE_DEVICE_ID = "sourceDeviceID"
|
||||
EVENT_TO_DESCRIPTION = {
|
||||
EVENT_TO_DESCRIPTION: dict[EventType[Any] | str, str] = {
|
||||
IOS_EVENT_ZONE_ENTERED: "entered zone",
|
||||
IOS_EVENT_ZONE_EXITED: "exited zone",
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ from homeassistant.helpers.integration_platform import (
|
|||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from . import entity_registry, websocket_api
|
||||
from .const import ( # noqa: F401
|
||||
|
@ -146,7 +147,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass_config_path=hass.config.path(DEFAULT_DB_FILE)
|
||||
)
|
||||
exclude = conf[CONF_EXCLUDE]
|
||||
exclude_event_types: set[str] = set(exclude.get(CONF_EVENT_TYPES, []))
|
||||
exclude_event_types: set[EventType[Any] | str] = set(
|
||||
exclude.get(CONF_EVENT_TYPES, [])
|
||||
)
|
||||
if EVENT_STATE_CHANGED in exclude_event_types:
|
||||
_LOGGER.error("State change events cannot be excluded, use a filter instead")
|
||||
exclude_event_types.remove(EVENT_STATE_CHANGED)
|
||||
|
|
|
@ -40,6 +40,7 @@ from homeassistant.helpers.start import async_at_started
|
|||
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.enum import try_parse_enum
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from . import migration, statistics
|
||||
from .const import (
|
||||
|
@ -173,7 +174,7 @@ class Recorder(threading.Thread):
|
|||
db_max_retries: int,
|
||||
db_retry_wait: int,
|
||||
entity_filter: Callable[[str], bool],
|
||||
exclude_event_types: set[str],
|
||||
exclude_event_types: set[EventType[Any] | str],
|
||||
) -> None:
|
||||
"""Initialize the recorder."""
|
||||
threading.Thread.__init__(self, name="Recorder")
|
||||
|
|
|
@ -2,9 +2,13 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
|
||||
def extract_event_type_ids(
|
||||
event_type_to_event_type_id: dict[str, int | None],
|
||||
event_type_to_event_type_id: dict[EventType[Any] | str, int | None],
|
||||
) -> list[int]:
|
||||
"""Extract event_type ids from event_type_to_event_type_id."""
|
||||
return [
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
"""Managers for each table."""
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from lru import LRU
|
||||
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core import Recorder
|
||||
|
||||
|
@ -13,7 +15,7 @@ _DataT = TypeVar("_DataT")
|
|||
class BaseTableManager(Generic[_DataT]):
|
||||
"""Base class for table managers."""
|
||||
|
||||
_id_map: "LRU[str, int]"
|
||||
_id_map: "LRU[EventType[Any] | str, int]"
|
||||
|
||||
def __init__(self, recorder: "Recorder") -> None:
|
||||
"""Initialize the table manager.
|
||||
|
@ -24,7 +26,7 @@ class BaseTableManager(Generic[_DataT]):
|
|||
"""
|
||||
self.active = False
|
||||
self.recorder = recorder
|
||||
self._pending: dict[str, _DataT] = {}
|
||||
self._pending: dict[EventType[Any] | str, _DataT] = {}
|
||||
|
||||
def get_from_cache(self, data: str) -> int | None:
|
||||
"""Resolve data to the id without accessing the underlying database.
|
||||
|
@ -34,7 +36,7 @@ class BaseTableManager(Generic[_DataT]):
|
|||
"""
|
||||
return self._id_map.get(data)
|
||||
|
||||
def get_pending(self, shared_data: str) -> _DataT | None:
|
||||
def get_pending(self, shared_data: EventType[Any] | str) -> _DataT | None:
|
||||
"""Get pending data that have not be assigned ids yet.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
|
|
|
@ -3,12 +3,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from lru import LRU
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.core import Event
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from ..db_schema import EventTypes
|
||||
from ..queries import find_event_type_ids
|
||||
|
@ -29,7 +30,9 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
|
|||
def __init__(self, recorder: Recorder) -> None:
|
||||
"""Initialize the event type manager."""
|
||||
super().__init__(recorder, CACHE_SIZE)
|
||||
self._non_existent_event_types: LRU[str, None] = LRU(CACHE_SIZE)
|
||||
self._non_existent_event_types: LRU[EventType[Any] | str, None] = LRU(
|
||||
CACHE_SIZE
|
||||
)
|
||||
|
||||
def load(self, events: list[Event], session: Session) -> None:
|
||||
"""Load the event_type to event_type_ids mapping into memory.
|
||||
|
@ -44,7 +47,10 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
|
|||
)
|
||||
|
||||
def get(
|
||||
self, event_type: str, session: Session, from_recorder: bool = False
|
||||
self,
|
||||
event_type: EventType[Any] | str,
|
||||
session: Session,
|
||||
from_recorder: bool = False,
|
||||
) -> int | None:
|
||||
"""Resolve event_type to the event_type_id.
|
||||
|
||||
|
@ -54,16 +60,19 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
|
|||
return self.get_many((event_type,), session)[event_type]
|
||||
|
||||
def get_many(
|
||||
self, event_types: Iterable[str], session: Session, from_recorder: bool = False
|
||||
) -> dict[str, int | None]:
|
||||
self,
|
||||
event_types: Iterable[EventType[Any] | str],
|
||||
session: Session,
|
||||
from_recorder: bool = False,
|
||||
) -> dict[EventType[Any] | str, int | None]:
|
||||
"""Resolve event_types to event_type_ids.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
results: dict[str, int | None] = {}
|
||||
missing: list[str] = []
|
||||
non_existent: list[str] = []
|
||||
results: dict[EventType[Any] | str, int | None] = {}
|
||||
missing: list[EventType[Any] | str] = []
|
||||
non_existent: list[EventType[Any] | str] = []
|
||||
|
||||
for event_type in event_types:
|
||||
if (event_type_id := self._id_map.get(event_type)) is None:
|
||||
|
@ -123,7 +132,7 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
|
|||
self.clear_non_existent(event_type)
|
||||
self._pending.clear()
|
||||
|
||||
def clear_non_existent(self, event_type: str) -> None:
|
||||
def clear_non_existent(self, event_type: EventType[Any] | str) -> None:
|
||||
"""Clear a non-existent event type from the cache.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
|
|
|
@ -12,6 +12,7 @@ import threading
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.helpers.typing import UndefinedType
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from . import entity_registry, purge, statistics
|
||||
from .const import DOMAIN
|
||||
|
@ -459,7 +460,7 @@ class EventIdMigrationTask(RecorderTask):
|
|||
class RefreshEventTypesTask(RecorderTask):
|
||||
"""An object to insert into the recorder queue to refresh event types."""
|
||||
|
||||
event_types: list[str]
|
||||
event_types: list[EventType[Any] | str]
|
||||
|
||||
def run(self, instance: Recorder) -> None:
|
||||
"""Refresh event types."""
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from enum import StrEnum
|
||||
from functools import partial
|
||||
from typing import Final
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from .helpers.deprecation import (
|
||||
DeprecatedConstant,
|
||||
|
@ -13,8 +13,12 @@ from .helpers.deprecation import (
|
|||
check_if_deprecated_constant,
|
||||
dir_with_deprecated_constants,
|
||||
)
|
||||
from .util.event_type import EventType
|
||||
from .util.signal_type import SignalType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import EventStateChangedData
|
||||
|
||||
APPLICATION_NAME: Final = "HomeAssistant"
|
||||
MAJOR_VERSION: Final = 2024
|
||||
MINOR_VERSION: Final = 5
|
||||
|
@ -306,7 +310,7 @@ EVENT_LOGBOOK_ENTRY: Final = "logbook_entry"
|
|||
EVENT_LOGGING_CHANGED: Final = "logging_changed"
|
||||
EVENT_SERVICE_REGISTERED: Final = "service_registered"
|
||||
EVENT_SERVICE_REMOVED: Final = "service_removed"
|
||||
EVENT_STATE_CHANGED: Final = "state_changed"
|
||||
EVENT_STATE_CHANGED: EventType[EventStateChangedData] = EventType("state_changed")
|
||||
EVENT_STATE_REPORTED: Final = "state_reported"
|
||||
EVENT_THEMES_UPDATED: Final = "themes_updated"
|
||||
EVENT_PANELS_UPDATED: Final = "panels_updated"
|
||||
|
|
|
@ -102,6 +102,7 @@ from .util.async_ import (
|
|||
run_callback_threadsafe,
|
||||
shutdown_run_callback_threadsafe,
|
||||
)
|
||||
from .util.event_type import EventType
|
||||
from .util.executor import InterruptibleThreadPoolExecutor
|
||||
from .util.json import JsonObjectType
|
||||
from .util.read_only_dict import ReadOnlyDict
|
||||
|
@ -1216,7 +1217,7 @@ class Event(Generic[_DataT]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
event_type: str,
|
||||
event_type: EventType[_DataT] | str,
|
||||
data: _DataT | None = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
time_fired_timestamp: float | None = None,
|
||||
|
@ -1290,7 +1291,7 @@ class Event(Generic[_DataT]):
|
|||
|
||||
|
||||
def _event_repr(
|
||||
event_type: str, origin: EventOrigin, data: Mapping[str, Any] | None
|
||||
event_type: EventType[_DataT] | str, origin: EventOrigin, data: _DataT | None
|
||||
) -> str:
|
||||
"""Return the representation."""
|
||||
if data:
|
||||
|
@ -1307,13 +1308,13 @@ _FilterableJobType = tuple[
|
|||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _OneTimeListener:
|
||||
class _OneTimeListener(Generic[_DataT]):
|
||||
hass: HomeAssistant
|
||||
listener_job: HassJob[[Event], Coroutine[Any, Any, None] | None]
|
||||
listener_job: HassJob[[Event[_DataT]], Coroutine[Any, Any, None] | None]
|
||||
remove: CALLBACK_TYPE | None = None
|
||||
|
||||
@callback
|
||||
def __call__(self, event: Event) -> None:
|
||||
def __call__(self, event: Event[_DataT]) -> None:
|
||||
"""Remove listener from event bus and then fire listener."""
|
||||
if not self.remove:
|
||||
# If the listener was already removed, we don't need to do anything
|
||||
|
@ -1341,7 +1342,7 @@ class EventBus:
|
|||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize a new event bus."""
|
||||
self._listeners: dict[str, list[_FilterableJobType[Any]]] = {}
|
||||
self._listeners: dict[EventType[Any] | str, list[_FilterableJobType[Any]]] = {}
|
||||
self._match_all_listeners: list[_FilterableJobType[Any]] = []
|
||||
self._listeners[MATCH_ALL] = self._match_all_listeners
|
||||
self._hass = hass
|
||||
|
@ -1356,7 +1357,7 @@ class EventBus:
|
|||
self._debug = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
|
||||
@callback
|
||||
def async_listeners(self) -> dict[str, int]:
|
||||
def async_listeners(self) -> dict[EventType[Any] | str, int]:
|
||||
"""Return dictionary with events and the number of listeners.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -1364,14 +1365,14 @@ class EventBus:
|
|||
return {key: len(listeners) for key, listeners in self._listeners.items()}
|
||||
|
||||
@property
|
||||
def listeners(self) -> dict[str, int]:
|
||||
def listeners(self) -> dict[EventType[Any] | str, int]:
|
||||
"""Return dictionary with events and the number of listeners."""
|
||||
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
|
||||
|
||||
def fire(
|
||||
self,
|
||||
event_type: str,
|
||||
event_data: Mapping[str, Any] | None = None,
|
||||
event_type: EventType[_DataT] | str,
|
||||
event_data: _DataT | None = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
context: Context | None = None,
|
||||
) -> None:
|
||||
|
@ -1383,8 +1384,8 @@ class EventBus:
|
|||
@callback
|
||||
def async_fire(
|
||||
self,
|
||||
event_type: str,
|
||||
event_data: Mapping[str, Any] | None = None,
|
||||
event_type: EventType[_DataT] | str,
|
||||
event_data: _DataT | None = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
context: Context | None = None,
|
||||
time_fired: float | None = None,
|
||||
|
@ -1402,8 +1403,8 @@ class EventBus:
|
|||
@callback
|
||||
def _async_fire(
|
||||
self,
|
||||
event_type: str,
|
||||
event_data: Mapping[str, Any] | None = None,
|
||||
event_type: EventType[_DataT] | str,
|
||||
event_data: _DataT | None = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
context: Context | None = None,
|
||||
time_fired: float | None = None,
|
||||
|
@ -1431,7 +1432,7 @@ class EventBus:
|
|||
if not listeners:
|
||||
return
|
||||
|
||||
event: Event | None = None
|
||||
event: Event[_DataT] | None = None
|
||||
|
||||
for job, event_filter, run_immediately in listeners:
|
||||
if event_filter is not None:
|
||||
|
@ -1461,8 +1462,8 @@ class EventBus:
|
|||
|
||||
def listen(
|
||||
self,
|
||||
event_type: str,
|
||||
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
|
||||
event_type: EventType[_DataT] | str,
|
||||
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen for all events or events of a specific type.
|
||||
|
||||
|
@ -1482,7 +1483,7 @@ class EventBus:
|
|||
@callback
|
||||
def async_listen(
|
||||
self,
|
||||
event_type: str,
|
||||
event_type: EventType[_DataT] | str,
|
||||
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
|
||||
event_filter: Callable[[_DataT], bool] | None = None,
|
||||
run_immediately: bool = True,
|
||||
|
@ -1524,7 +1525,9 @@ class EventBus:
|
|||
|
||||
@callback
|
||||
def _async_listen_filterable_job(
|
||||
self, event_type: str, filterable_job: _FilterableJobType[Any]
|
||||
self,
|
||||
event_type: EventType[_DataT] | str,
|
||||
filterable_job: _FilterableJobType[_DataT],
|
||||
) -> CALLBACK_TYPE:
|
||||
self._listeners.setdefault(event_type, []).append(filterable_job)
|
||||
return functools.partial(
|
||||
|
@ -1533,8 +1536,8 @@ class EventBus:
|
|||
|
||||
def listen_once(
|
||||
self,
|
||||
event_type: str,
|
||||
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
|
||||
event_type: EventType[_DataT] | str,
|
||||
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen once for event of a specific type.
|
||||
|
||||
|
@ -1556,8 +1559,8 @@ class EventBus:
|
|||
@callback
|
||||
def async_listen_once(
|
||||
self,
|
||||
event_type: str,
|
||||
listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None],
|
||||
event_type: EventType[_DataT] | str,
|
||||
listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None],
|
||||
run_immediately: bool = True,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen once for event of a specific type.
|
||||
|
@ -1569,7 +1572,9 @@ class EventBus:
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
one_time_listener = _OneTimeListener(self._hass, HassJob(listener))
|
||||
one_time_listener: _OneTimeListener[_DataT] = _OneTimeListener(
|
||||
self._hass, HassJob(listener)
|
||||
)
|
||||
remove = self._async_listen_filterable_job(
|
||||
event_type,
|
||||
(
|
||||
|
@ -1587,7 +1592,9 @@ class EventBus:
|
|||
|
||||
@callback
|
||||
def _async_remove_listener(
|
||||
self, event_type: str, filterable_job: _FilterableJobType
|
||||
self,
|
||||
event_type: EventType[_DataT] | str,
|
||||
filterable_job: _FilterableJobType[_DataT],
|
||||
) -> None:
|
||||
"""Remove a listener of a specific event_type.
|
||||
|
||||
|
|
|
@ -4,7 +4,9 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .util.event_type import EventType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Context
|
||||
|
@ -271,8 +273,12 @@ class ServiceNotFound(HomeAssistantError):
|
|||
class MaxLengthExceeded(HomeAssistantError):
|
||||
"""Raised when a property value has exceeded the max character length."""
|
||||
|
||||
def __init__(self, value: str, property_name: str, max_length: int) -> None:
|
||||
def __init__(
|
||||
self, value: EventType[Any] | str, property_name: str, max_length: int
|
||||
) -> None:
|
||||
"""Initialize error."""
|
||||
if TYPE_CHECKING:
|
||||
value = str(value)
|
||||
super().__init__(
|
||||
translation_domain="homeassistant",
|
||||
translation_key="max_length_exceeded",
|
||||
|
|
|
@ -38,6 +38,7 @@ from homeassistant.exceptions import TemplateError
|
|||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
from .device_registry import (
|
||||
EVENT_DEVICE_REGISTRY_UPDATED,
|
||||
|
@ -90,7 +91,7 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
|
|||
|
||||
listeners_key: str
|
||||
callbacks_key: str
|
||||
event_type: str
|
||||
event_type: EventType[_TypedDictT] | str
|
||||
dispatcher_callable: Callable[
|
||||
[
|
||||
HomeAssistant,
|
||||
|
|
20
homeassistant/util/event_type.py
Normal file
20
homeassistant/util/event_type.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
"""Implementation for EventType.
|
||||
|
||||
Custom for type checking. See stub file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
|
||||
|
||||
|
||||
class EventType(str, Generic[_DataT]):
|
||||
"""Custom type for Event.event_type.
|
||||
|
||||
At runtime this is a generic subclass of str.
|
||||
"""
|
25
homeassistant/util/event_type.pyi
Normal file
25
homeassistant/util/event_type.pyi
Normal file
|
@ -0,0 +1,25 @@
|
|||
"""Stub file for event_type. Provide overload for type checking."""
|
||||
# ruff: noqa: PYI021 # Allow docstrings
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
__all__ = [
|
||||
"EventType",
|
||||
]
|
||||
|
||||
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
|
||||
|
||||
class EventType(Generic[_DataT]):
|
||||
"""Custom type for Event.event_type. At runtime delegated to str.
|
||||
|
||||
For type checkers pretend to be its own separate class.
|
||||
"""
|
||||
|
||||
def __init__(self, value: str, /) -> None: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __eq__(self, value: object, /) -> bool: ...
|
||||
def __getitem__(self, index: int) -> str: ...
|
25
tests/util/test_event_type.py
Normal file
25
tests/util/test_event_type.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
"""Test EventType implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import orjson
|
||||
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
||||
|
||||
def test_compatibility_with_str() -> None:
|
||||
"""Test EventType. At runtime it should be (almost) fully compatible with str."""
|
||||
|
||||
event = EventType("Hello World")
|
||||
assert event == "Hello World"
|
||||
assert len(event) == 11
|
||||
assert hash(event) == hash("Hello World")
|
||||
d: dict[str | EventType, int] = {EventType("key"): 2}
|
||||
assert d["key"] == 2
|
||||
|
||||
|
||||
def test_json_dump() -> None:
|
||||
"""Test EventType json dump with orjson."""
|
||||
|
||||
event = EventType("state_changed")
|
||||
assert orjson.dumps({"event_type": event}) == b'{"event_type":"state_changed"}'
|
Loading…
Add table
Reference in a new issue