diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index a8746a0a807..ad05cad3d54 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -11,7 +11,7 @@ import queue import sqlite3 import threading import time -from typing import Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import psutil_home_assistant as ha_psutil from sqlalchemy import create_engine, event as sqlalchemy_event, exc, select @@ -104,7 +104,6 @@ from .tasks import ( EntityIDPostMigrationTask, EventIdMigrationTask, EventsContextIDMigrationTask, - EventTask, EventTypeIDMigrationTask, ImportStatisticsTask, KeepAliveTask, @@ -189,7 +188,7 @@ class Recorder(threading.Thread): self.keep_days = keep_days self._hass_started: asyncio.Future[object] = hass.loop.create_future() self.commit_interval = commit_interval - self._queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue() + self._queue: queue.SimpleQueue[RecorderTask | Event] = queue.SimpleQueue() self.db_url = uri self.db_max_retries = db_max_retries self.db_retry_wait = db_retry_wait @@ -278,7 +277,7 @@ class Recorder(threading.Thread): raise RuntimeError("The database connection has not been established") return self._get_session() - def queue_task(self, task: RecorderTask) -> None: + def queue_task(self, task: RecorderTask | Event) -> None: """Add a task to the recorder queue.""" self._queue.put(task) @@ -306,7 +305,6 @@ class Recorder(threading.Thread): entity_filter = self.entity_filter exclude_event_types = self.exclude_event_types queue_put = self._queue.put_nowait - event_task = EventTask @callback def _event_listener(event: Event) -> None: @@ -315,23 +313,23 @@ class Recorder(threading.Thread): return if (entity_id := event.data.get(ATTR_ENTITY_ID)) is None: - queue_put(event_task(event)) + queue_put(event) return if isinstance(entity_id, str): if entity_filter(entity_id): - queue_put(event_task(event)) + queue_put(event) return if isinstance(entity_id, list): for eid in entity_id: if entity_filter(eid): - queue_put(event_task(event)) + queue_put(event) return return # Unknown what it is. - queue_put(event_task(event)) + queue_put(event) self._event_listener = self.hass.bus.async_listen( MATCH_ALL, @@ -857,31 +855,35 @@ class Recorder(threading.Thread): # with a commit every time the event time # has changed. This reduces the disk io. queue_ = self._queue - startup_tasks: list[RecorderTask] = [] - while not queue_.empty() and (task := queue_.get_nowait()): - startup_tasks.append(task) - self._pre_process_startup_tasks(startup_tasks) - for task in startup_tasks: - self._guarded_process_one_task_or_recover(task) + startup_task_or_events: list[RecorderTask | Event] = [] + while not queue_.empty() and (task_or_event := queue_.get_nowait()): + startup_task_or_events.append(task_or_event) + self._pre_process_startup_events(startup_task_or_events) + for task in startup_task_or_events: + self._guarded_process_one_task_or_event_or_recover(task) # Clear startup tasks since this thread runs forever # and we don't want to hold them in memory - del startup_tasks + del startup_task_or_events self.stop_requested = False while not self.stop_requested: - self._guarded_process_one_task_or_recover(queue_.get()) + self._guarded_process_one_task_or_event_or_recover(queue_.get()) - def _pre_process_startup_tasks(self, startup_tasks: list[RecorderTask]) -> None: - """Pre process startup tasks.""" + def _pre_process_startup_events( + self, startup_task_or_events: list[RecorderTask | Event] + ) -> None: + """Pre process startup events.""" # Prime all the state_attributes and event_data caches # before we start processing events state_change_events: list[Event] = [] non_state_change_events: list[Event] = [] - for task in startup_tasks: - if isinstance(task, EventTask): - event_ = task.event + for task_or_event in startup_task_or_events: + # Event is never subclassed so we can + # use a fast type check + if type(task_or_event) is Event: # noqa: E721 + event_ = task_or_event if event_.event_type == EVENT_STATE_CHANGED: state_change_events.append(event_) else: @@ -894,20 +896,31 @@ class Recorder(threading.Thread): self.states_meta_manager.load(state_change_events, session) self.state_attributes_manager.load(state_change_events, session) - def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None: + def _guarded_process_one_task_or_event_or_recover( + self, task: RecorderTask | Event + ) -> None: """Process a task, guarding against exceptions to ensure the loop does not collapse.""" _LOGGER.debug("Processing task: %s", task) try: - self._process_one_task_or_recover(task) + self._process_one_task_or_event_or_recover(task) except Exception as err: # pylint: disable=broad-except _LOGGER.exception("Error while processing event %s: %s", task, err) - def _process_one_task_or_recover(self, task: RecorderTask) -> None: - """Process an event, reconnect, or recover a malformed database.""" + def _process_one_task_or_event_or_recover(self, task: RecorderTask | Event) -> None: + """Process a task or event, reconnect, or recover a malformed database.""" try: + # Almost everything coming in via the queue + # is an Event so we can process it directly + # and since its never subclassed, we can + # use a fast type check + if type(task) is Event: # noqa: E721 + self._process_one_event(task) + return # If its not an event, commit everything # that is pending before running the task - if task.commit_before: + if TYPE_CHECKING: + assert isinstance(task, RecorderTask) + if not task.commit_before: self._commit_event_session_or_retry() return task.run(self) except exc.DatabaseError as err: diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 07be6202a0c..c062eb3915f 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -10,7 +10,6 @@ import logging import threading from typing import TYPE_CHECKING, Any -from homeassistant.core import Event from homeassistant.helpers.typing import UndefinedType from . import entity_registry, purge, statistics @@ -268,19 +267,6 @@ class StopTask(RecorderTask): instance.stop_requested = True -@dataclass(slots=True) -class EventTask(RecorderTask): - """An event to be processed.""" - - event: Event - commit_before = False - - def run(self, instance: Recorder) -> None: - """Handle the task.""" - # pylint: disable-next=[protected-access] - instance._process_one_event(self.event) - - @dataclass(slots=True) class KeepAliveTask(RecorderTask): """A keep alive to be sent.""" diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index ede5bc32a6f..db4074a8fdb 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -399,7 +399,7 @@ async def test_schema_migrate( ), patch( "homeassistant.components.recorder.Recorder._process_non_state_changed_event_into_session", ), patch( - "homeassistant.components.recorder.Recorder._pre_process_startup_tasks", + "homeassistant.components.recorder.Recorder._pre_process_startup_events", ): recorder_helper.async_initialize_recorder(hass) hass.async_create_task(