Improve recorder event typing (#115253)

This commit is contained in:
Marc Mueller 2024-04-09 02:56:18 +02:00 committed by GitHub
parent 4e983d710f
commit d4500cf945
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 30 additions and 19 deletions

View file

@ -30,7 +30,13 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
MATCH_ALL, MATCH_ALL,
) )
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
callback,
)
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
async_track_time_change, async_track_time_change,
async_track_time_interval, async_track_time_interval,
@ -862,12 +868,12 @@ class Recorder(threading.Thread):
self._guarded_process_one_task_or_event_or_recover(queue_.get()) self._guarded_process_one_task_or_event_or_recover(queue_.get())
def _pre_process_startup_events( def _pre_process_startup_events(
self, startup_task_or_events: list[RecorderTask | Event] self, startup_task_or_events: list[RecorderTask | Event[Any]]
) -> None: ) -> None:
"""Pre process startup events.""" """Pre process startup events."""
# Prime all the state_attributes and event_data caches # Prime all the state_attributes and event_data caches
# before we start processing events # before we start processing events
state_change_events: list[Event] = [] state_change_events: list[Event[EventStateChangedData]] = []
non_state_change_events: list[Event] = [] non_state_change_events: list[Event] = []
for task_or_event in startup_task_or_events: for task_or_event in startup_task_or_events:
@ -1019,7 +1025,7 @@ class Recorder(threading.Thread):
self.backlog, self.backlog,
) )
def _process_one_event(self, event: Event) -> None: def _process_one_event(self, event: Event[Any]) -> None:
if not self.enabled: if not self.enabled:
return return
if event.event_type == EVENT_STATE_CHANGED: if event.event_type == EVENT_STATE_CHANGED:
@ -1076,7 +1082,9 @@ class Recorder(threading.Thread):
self._add_to_session(session, dbevent) self._add_to_session(session, dbevent)
def _process_state_changed_event_into_session(self, event: Event) -> None: def _process_state_changed_event_into_session(
self, event: Event[EventStateChangedData]
) -> None:
"""Process a state_changed event into the session.""" """Process a state_changed event into the session."""
state_attributes_manager = self.state_attributes_manager state_attributes_manager = self.state_attributes_manager
states_meta_manager = self.states_meta_manager states_meta_manager = self.states_meta_manager

View file

@ -40,7 +40,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_ENTITY_ID, MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE, MAX_LENGTH_STATE_STATE,
) )
from homeassistant.core import Context, Event, EventOrigin, State from homeassistant.core import Context, Event, EventOrigin, EventStateChangedData, State
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.json import ( from homeassistant.util.json import (
@ -478,10 +478,10 @@ class States(Base):
return date_time.isoformat(sep=" ", timespec="seconds") return date_time.isoformat(sep=" ", timespec="seconds")
@staticmethod @staticmethod
def from_event(event: Event) -> States: def from_event(event: Event[EventStateChangedData]) -> States:
"""Create object from a state_changed event.""" """Create object from a state_changed event."""
entity_id = event.data["entity_id"] entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state") state = event.data["new_state"]
dbstate = States( dbstate = States(
entity_id=entity_id, entity_id=entity_id,
attributes=None, attributes=None,
@ -576,13 +576,12 @@ class StateAttributes(Base):
@staticmethod @staticmethod
def shared_attrs_bytes_from_event( def shared_attrs_bytes_from_event(
event: Event, event: Event[EventStateChangedData],
dialect: SupportedDialect | None, dialect: SupportedDialect | None,
) -> bytes: ) -> bytes:
"""Create shared_attrs from a state_changed event.""" """Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine # None state means the state was removed from the state machine
if state is None: if (state := event.data["new_state"]) is None:
return b"{}" return b"{}"
if state_info := state.state_info: if state_info := state.state_info:
exclude_attrs = { exclude_attrs = {

View file

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event from homeassistant.core import Event, EventStateChangedData
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from ..db_schema import StateAttributes from ..db_schema import StateAttributes
@ -38,7 +38,7 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
super().__init__(recorder, CACHE_SIZE) super().__init__(recorder, CACHE_SIZE)
self.active = True # always active self.active = True # always active
def serialize_from_event(self, event: Event) -> bytes | None: def serialize_from_event(self, event: Event[EventStateChangedData]) -> bytes | None:
"""Serialize event data.""" """Serialize event data."""
try: try:
return StateAttributes.shared_attrs_bytes_from_event( return StateAttributes.shared_attrs_bytes_from_event(
@ -47,12 +47,14 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
except JSON_ENCODE_EXCEPTIONS as ex: except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning( _LOGGER.warning(
"State is not JSON serializable: %s: %s", "State is not JSON serializable: %s: %s",
event.data.get("new_state"), event.data["new_state"],
ex, ex,
) )
return None return None
def load(self, events: list[Event], session: Session) -> None: def load(
self, events: list[Event[EventStateChangedData]], session: Session
) -> None:
"""Load the shared_attrs to attributes_ids mapping into memory from events. """Load the shared_attrs to attributes_ids mapping into memory from events.
This call is not thread-safe and must be called from the This call is not thread-safe and must be called from the

View file

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event from homeassistant.core import Event, EventStateChangedData
from ..db_schema import StatesMeta from ..db_schema import StatesMeta
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
@ -28,7 +28,9 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
self._did_first_load = False self._did_first_load = False
super().__init__(recorder, CACHE_SIZE) super().__init__(recorder, CACHE_SIZE)
def load(self, events: list[Event], session: Session) -> None: def load(
self, events: list[Event[EventStateChangedData]], session: Session
) -> None:
"""Load the entity_id to metadata_id mapping into memory. """Load the entity_id to metadata_id mapping into memory.
This call is not thread-safe and must be called from the This call is not thread-safe and must be called from the
@ -37,9 +39,9 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
self._did_first_load = True self._did_first_load = True
self.get_many( self.get_many(
{ {
event.data["new_state"].entity_id new_state.entity_id
for event in events for event in events
if event.data.get("new_state") is not None if (new_state := event.data["new_state"]) is not None
}, },
session, session,
True, True,