From e379aa23bdc3ec469fd363e300ec329b85da3158 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 15 Mar 2023 15:26:29 -1000 Subject: [PATCH] Migrate StateAttributes to use a table manager (#89760) Co-authored-by: Paulus Schoutsen --- homeassistant/components/recorder/core.py | 142 ++++------------ homeassistant/components/recorder/purge.py | 24 +-- homeassistant/components/recorder/queries.py | 11 -- .../recorder/table_managers/__init__.py | 66 +++++++- .../recorder/table_managers/event_data.py | 34 +--- .../recorder/table_managers/event_types.py | 26 +-- .../table_managers/state_attributes.py | 160 ++++++++++++++++++ .../recorder/table_managers/states_meta.py | 26 +-- tests/components/recorder/test_init.py | 12 +- 9 files changed, 274 insertions(+), 227 deletions(-) create mode 100644 homeassistant/components/recorder/table_managers/state_attributes.py diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 26cd5c3b889..e7fdf645812 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -11,10 +11,9 @@ import queue import sqlite3 import threading import time -from typing import Any, TypeVar, cast +from typing import Any, TypeVar import async_timeout -from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError @@ -30,7 +29,6 @@ from homeassistant.const import ( MATCH_ALL, ) from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback -from homeassistant.helpers.entity import entity_sources from homeassistant.helpers.event import ( async_track_time_change, async_track_time_interval, @@ -40,7 +38,6 @@ from homeassistant.helpers.start import async_at_started from homeassistant.helpers.typing import UNDEFINED, UndefinedType import homeassistant.util.dt as dt_util from homeassistant.util.enum import try_parse_enum -from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS from . 
import migration, statistics from .const import ( @@ -52,7 +49,6 @@ from .const import ( MAX_QUEUE_BACKLOG, MYSQLDB_PYMYSQL_URL_PREFIX, MYSQLDB_URL_PREFIX, - SQLITE_MAX_BIND_VARS, SQLITE_URL_PREFIX, SupportedDialect, ) @@ -79,8 +75,6 @@ from .models import ( ) from .pool import POOL_SIZE, MutexPool, RecorderPool from .queries import ( - find_shared_attributes_id, - get_shared_attributes, has_entity_ids_to_migrate, has_event_type_to_migrate, has_events_context_ids_to_migrate, @@ -89,6 +83,7 @@ from .queries import ( from .run_history import RunHistory from .table_managers.event_data import EventDataManager from .table_managers.event_types import EventTypeManager +from .table_managers.state_attributes import StateAttributesManager from .table_managers.states_meta import StatesMetaManager from .tasks import ( AdjustLRUSizeTask, @@ -115,7 +110,6 @@ from .tasks import ( ) from .util import ( build_mysqldb_conv, - chunked, dburl_to_path, end_incomplete_runs, is_second_sunday, @@ -136,15 +130,6 @@ DEFAULT_URL = "sqlite:///{hass_config_path}" # States and Events objects EXPIRE_AFTER_COMMITS = 120 -# The number of attribute ids to cache in memory -# -# Based on: -# - The number of overlapping attributes -# - How frequently states with overlapping attributes will change -# - How much memory our low end hardware has -STATE_ATTRIBUTES_ID_CACHE_SIZE = 2048 - - SHUTDOWN_TASK = object() COMMIT_TASK = CommitTask() @@ -206,7 +191,6 @@ class Recorder(threading.Thread): self._queue_watch = threading.Event() self.engine: Engine | None = None self.run_history = RunHistory() - self._entity_sources = entity_sources(hass) # The entity_filter is exposed on the recorder instance so that # it can be used to see if an entity is being recorded and is called @@ -217,11 +201,12 @@ class Recorder(threading.Thread): self.schema_version = 0 self._commits_without_expire = 0 self._old_states: dict[str | None, States] = {} - self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE) self.event_data_manager = EventDataManager(self) self.event_type_manager = EventTypeManager(self) self.states_meta_manager = StatesMetaManager(self) - self._pending_state_attributes: dict[str, StateAttributes] = {} + self.state_attributes_manager = StateAttributesManager( + self, exclude_attributes_by_domain + ) self._pending_expunge: list[States] = [] self.event_session: Session | None = None self._get_session: Callable[[], Session] | None = None @@ -231,7 +216,6 @@ class Recorder(threading.Thread): self.migration_is_live = False self._database_lock_task: DatabaseLockTask | None = None self._db_executor: DBInterruptibleThreadPoolExecutor | None = None - self._exclude_attributes_by_domain = exclude_attributes_by_domain self._event_listener: CALLBACK_TYPE | None = None self._queue_watcher: CALLBACK_TYPE | None = None @@ -507,11 +491,9 @@ class Recorder(threading.Thread): If the number of entities has increased, increase the size of the LRU cache to avoid thrashing. 
""" - state_attributes_lru = self._state_attributes_ids - current_size = state_attributes_lru.get_size() new_size = self.hass.states.async_entity_ids_count() * 2 - if new_size > current_size: - state_attributes_lru.set_size(new_size) + self.state_attributes_manager.adjust_lru_size(new_size) + self.states_meta_manager.adjust_lru_size(new_size) @callback def async_periodic_statistics(self) -> None: @@ -776,33 +758,10 @@ class Recorder(threading.Thread): non_state_change_events.append(event_) assert self.event_session is not None - self._pre_process_state_change_events(state_change_events) self.event_data_manager.load(non_state_change_events, self.event_session) self.event_type_manager.load(non_state_change_events, self.event_session) self.states_meta_manager.load(state_change_events, self.event_session) - - def _pre_process_state_change_events(self, events: list[Event]) -> None: - """Load startup state attributes from the database. - - Since the _state_attributes_ids cache is empty at startup - we restore it from the database to avoid having to look up - the attributes in the database for every state change - until its primed. - """ - assert self.event_session is not None - if hashes := { - StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) - for event in events - if ( - shared_attrs_bytes := self._serialize_state_attributes_from_event(event) - ) - }: - with self.event_session.no_autoflush: - for hash_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS): - for id_, shared_attrs in self.event_session.execute( - get_shared_attributes(hash_chunk) - ).fetchall(): - self._state_attributes_ids[shared_attrs] = id_ + self.state_attributes_manager.load(state_change_events, self.event_session) def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None: """Process a task, guarding against exceptions to ensure the loop does not collapse.""" @@ -932,24 +891,6 @@ class Recorder(threading.Thread): if not self.commit_interval: self._commit_event_session_or_retry() - def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None: - """Find shared attributes in the db from the hash and shared_attrs.""" - # - # Avoid the event session being flushed since it will - # commit all the pending events and states to the database. - # - # The lookup has already have checked to see if the data is cached - # or going to be written in the next commit so there is no - # need to flush before checking the database. 
- # - assert self.event_session is not None - with self.event_session.no_autoflush: - if attributes_id := self.event_session.execute( - find_shared_attributes_id(attr_hash, shared_attrs) - ).first(): - return cast(int, attributes_id[0]) - return None - def _process_non_state_changed_event_into_session(self, event: Event) -> None: """Process any event into the session except state changed.""" session = self.event_session @@ -996,67 +937,53 @@ class Recorder(threading.Thread): session.add(dbevent) - def _serialize_state_attributes_from_event(self, event: Event) -> bytes | None: - """Serialize state changed event data.""" - try: - return StateAttributes.shared_attrs_bytes_from_event( - event, - self._entity_sources, - self._exclude_attributes_by_domain, - self.dialect_name, - ) - except JSON_ENCODE_EXCEPTIONS as ex: - _LOGGER.warning( - "State is not JSON serializable: %s: %s", - event.data.get("new_state"), - ex, - ) - return None - def _process_state_changed_event_into_session(self, event: Event) -> None: """Process a state_changed event into the session.""" + state_attributes_manager = self.state_attributes_manager dbstate = States.from_event(event) if (entity_id := dbstate.entity_id) is None or not ( - shared_attrs_bytes := self._serialize_state_attributes_from_event(event) + shared_attrs_bytes := state_attributes_manager.serialize_from_event(event) ): return assert self.event_session is not None - event_session = self.event_session + session = self.event_session # Map the entity_id to the StatesMeta table states_meta_manager = self.states_meta_manager if pending_states_meta := states_meta_manager.get_pending(entity_id): dbstate.states_meta_rel = pending_states_meta - elif metadata_id := states_meta_manager.get(entity_id, event_session, True): + elif metadata_id := states_meta_manager.get(entity_id, session, True): dbstate.metadata_id = metadata_id else: states_meta = StatesMeta(entity_id=entity_id) states_meta_manager.add_pending(states_meta) - event_session.add(states_meta) + session.add(states_meta) dbstate.states_meta_rel = states_meta + # Map the event data to the StateAttributes table shared_attrs = shared_attrs_bytes.decode("utf-8") dbstate.attributes = None # Matching attributes found in the pending commit - if pending_attributes := self._pending_state_attributes.get(shared_attrs): - dbstate.state_attributes = pending_attributes + if pending_event_data := state_attributes_manager.get_pending(shared_attrs): + dbstate.state_attributes = pending_event_data # Matching attributes id found in the cache - elif attributes_id := self._state_attributes_ids.get(shared_attrs): + elif ( + attributes_id := state_attributes_manager.get_from_cache(shared_attrs) + ) or ( + (hash_ := StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)) + and ( + attributes_id := state_attributes_manager.get( + shared_attrs, hash_, session + ) + ) + ): dbstate.attributes_id = attributes_id else: - attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) - # Matching attributes found in the database - if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs): - dbstate.attributes_id = attributes_id - self._state_attributes_ids[shared_attrs] = attributes_id # No matching attributes found, save them in the DB - else: - dbstate_attributes = StateAttributes( - shared_attrs=shared_attrs, hash=attr_hash - ) - dbstate.state_attributes = dbstate_attributes - self._pending_state_attributes[shared_attrs] = dbstate_attributes - self.event_session.add(dbstate_attributes) + 
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_) + state_attributes_manager.add_pending(dbstate_attributes) + session.add(dbstate_attributes) + dbstate.state_attributes = dbstate_attributes if old_state := self._old_states.pop(entity_id, None): if old_state.state_id: @@ -1128,11 +1055,7 @@ class Recorder(threading.Thread): # and we now know the attributes_ids. We can save # many selects for matching attributes by loading them # into the LRU cache now. - for state_attr in self._pending_state_attributes.values(): - self._state_attributes_ids[ - state_attr.shared_attrs - ] = state_attr.attributes_id - self._pending_state_attributes = {} + self.state_attributes_manager.post_commit_pending() self.event_data_manager.post_commit_pending() self.event_type_manager.post_commit_pending() self.states_meta_manager.post_commit_pending() @@ -1158,8 +1081,7 @@ class Recorder(threading.Thread): def _close_event_session(self) -> None: """Close the event session.""" self._old_states.clear() - self._state_attributes_ids.clear() - self._pending_state_attributes.clear() + self.state_attributes_manager.reset() self.event_data_manager.reset() self.event_type_manager.reset() self.states_meta_manager.reset() diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py index 5dffead5978..08122b9fba7 100644 --- a/homeassistant/components/recorder/purge.py +++ b/homeassistant/components/recorder/purge.py @@ -479,28 +479,6 @@ def _evict_purged_states_from_old_states_cache( old_states.pop(old_state_reversed[purged_state_id], None) -def _evict_purged_attributes_from_attributes_cache( - instance: Recorder, purged_attributes_ids: set[int] -) -> None: - """Evict purged attribute ids from the attribute ids cache.""" - # Make a map from attributes_id to the attributes json - state_attributes_ids = ( - instance._state_attributes_ids # pylint: disable=protected-access - ) - state_attributes_ids_reversed = { - attributes_id: attributes - for attributes, attributes_id in state_attributes_ids.items() - } - - # Evict any purged attributes from the state_attributes_ids cache - for purged_attribute_id in purged_attributes_ids.intersection( - state_attributes_ids_reversed - ): - state_attributes_ids.pop( - state_attributes_ids_reversed[purged_attribute_id], None - ) - - def _purge_batch_attributes_ids( instance: Recorder, session: Session, attributes_ids: set[int] ) -> None: @@ -512,7 +490,7 @@ def _purge_batch_attributes_ids( _LOGGER.debug("Deleted %s attribute states", deleted_rows) # Evict any entries in the state_attributes_ids cache referring to a purged state - _evict_purged_attributes_from_attributes_cache(instance, attributes_ids) + instance.state_attributes_manager.evict_purged(attributes_ids) def _purge_batch_data_ids( diff --git a/homeassistant/components/recorder/queries.py b/homeassistant/components/recorder/queries.py index 0882da9d48c..5a2c7040f43 100644 --- a/homeassistant/components/recorder/queries.py +++ b/homeassistant/components/recorder/queries.py @@ -74,17 +74,6 @@ def find_states_metadata_ids(entity_ids: Iterable[str]) -> StatementLambdaElemen ) -def find_shared_attributes_id( - data_hash: int, shared_attrs: str -) -> StatementLambdaElement: - """Find an attributes_id by hash and shared_attrs.""" - return lambda_stmt( - lambda: select(StateAttributes.attributes_id) - .filter(StateAttributes.hash == data_hash) - .filter(StateAttributes.shared_attrs == shared_attrs) - ) - - def _state_attrs_exist(attr: int | None) -> Select: """Check if a state 
attributes id exists in the states table.""" # https://github.com/sqlalchemy/sqlalchemy/issues/9189 diff --git a/homeassistant/components/recorder/table_managers/__init__.py b/homeassistant/components/recorder/table_managers/__init__.py index 50ea8f0e11f..e56ee4f3415 100644 --- a/homeassistant/components/recorder/table_managers/__init__.py +++ b/homeassistant/components/recorder/table_managers/__init__.py @@ -1,15 +1,75 @@ """Managers for each table.""" -from typing import TYPE_CHECKING +from collections.abc import MutableMapping +from typing import TYPE_CHECKING, Generic, TypeVar + +from lru import LRU # pylint: disable=no-name-in-module if TYPE_CHECKING: from ..core import Recorder +_DataT = TypeVar("_DataT") -class BaseTableManager: + +class BaseTableManager(Generic[_DataT]): """Base class for table managers.""" def __init__(self, recorder: "Recorder") -> None: - """Initialize the table manager.""" + """Initialize the table manager. + + The table manager is responsible for managing the id mappings + for a table. When data is committed to the database, the + manager will move the data from the pending map to the id map. + """ self.active = False self.recorder = recorder + self._pending: dict[str, _DataT] = {} + self._id_map: MutableMapping[str, int] = {} + + def get_from_cache(self, data: str) -> int | None: + """Resolve data to the id without accessing the underlying database. + + This call is not thread-safe and must be called from the + recorder thread. + """ + return self._id_map.get(data) + + def get_pending(self, shared_data: str) -> _DataT | None: + """Get pending data that have not been assigned ids yet. + + This call is not thread-safe and must be called from the + recorder thread. + """ + return self._pending.get(shared_data) + + def reset(self) -> None: + """Reset after the database has been reset or changed. + + This call is not thread-safe and must be called from the + recorder thread. + """ + self._id_map.clear() + self._pending.clear() + + +class BaseLRUTableManager(BaseTableManager[_DataT]): + """Base class for LRU table managers.""" + + def __init__(self, recorder: "Recorder", lru_size: int) -> None: + """Initialize the LRU table manager. + + We keep track of the most recently used items + and evict the least recently used items when the cache is full. + """ + super().__init__(recorder) + self._id_map: MutableMapping[str, int] = LRU(lru_size) + + def adjust_lru_size(self, new_size: int) -> None: + """Adjust the LRU cache size. + + This call is not thread-safe and must be called from the + recorder thread. + """ + lru: LRU = self._id_map + if new_size > lru.get_size(): + lru.set_size(new_size) diff --git a/homeassistant/components/recorder/table_managers/event_data.py b/homeassistant/components/recorder/table_managers/event_data.py index c877f08f878..a99b25fe0b4 100644 --- a/homeassistant/components/recorder/table_managers/event_data.py +++ b/homeassistant/components/recorder/table_managers/event_data.py @@ -5,13 +5,12 @@ from collections.abc import Iterable import logging from typing import TYPE_CHECKING, cast -from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy.orm.session import Session from homeassistant.core import Event from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS -from . import BaseTableManager +from . 
import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import EventData from ..queries import get_shared_event_datas @@ -26,14 +25,12 @@ CACHE_SIZE = 2048 _LOGGER = logging.getLogger(__name__) -class EventDataManager(BaseTableManager): +class EventDataManager(BaseLRUTableManager[EventData]): """Manage the EventData table.""" def __init__(self, recorder: Recorder) -> None: """Initialize the event type manager.""" - self._id_map: dict[str, int] = LRU(CACHE_SIZE) - self._pending: dict[str, EventData] = {} - super().__init__(recorder) + super().__init__(recorder, CACHE_SIZE) self.active = True # always active def serialize_from_event(self, event: Event) -> bytes | None: @@ -67,14 +64,6 @@ class EventDataManager(BaseTableManager): """ return self.get_many(((shared_data, data_hash),), session)[shared_data] - def get_from_cache(self, shared_data: str) -> int | None: - """Resolve shared_data to the data_id without accessing the underlying database. - - This call is not thread-safe and must be called from the - recorder thread. - """ - return self._id_map.get(shared_data) - def get_many( self, shared_data_data_hashs: Iterable[tuple[str, int]], session: Session ) -> dict[str, int | None]: @@ -116,14 +105,6 @@ class EventDataManager(BaseTableManager): return results - def get_pending(self, shared_data: str) -> EventData | None: - """Get pending EventData that have not be assigned ids yet. - - This call is not thread-safe and must be called from the - recorder thread. - """ - return self._pending.get(shared_data) - def add_pending(self, db_event_data: EventData) -> None: """Add a pending EventData that will be committed at the next interval. @@ -144,15 +125,6 @@ class EventDataManager(BaseTableManager): self._id_map[shared_data] = db_event_data.data_id self._pending.clear() - def reset(self) -> None: - """Reset the event manager after the database has been reset or changed. - - This call is not thread-safe and must be called from the - recorder thread. - """ - self._id_map.clear() - self._pending.clear() - def evict_purged(self, data_ids: set[int]) -> None: """Evict purged data_ids from the cache when they are no longer used. diff --git a/homeassistant/components/recorder/table_managers/event_types.py b/homeassistant/components/recorder/table_managers/event_types.py index b31382336cc..3cb3d9fad97 100644 --- a/homeassistant/components/recorder/table_managers/event_types.py +++ b/homeassistant/components/recorder/table_managers/event_types.py @@ -4,12 +4,11 @@ from __future__ import annotations from collections.abc import Iterable from typing import TYPE_CHECKING, cast -from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy.orm.session import Session from homeassistant.core import Event -from . import BaseTableManager +from . import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import EventTypes from ..queries import find_event_type_ids @@ -22,14 +21,12 @@ if TYPE_CHECKING: CACHE_SIZE = 2048 -class EventTypeManager(BaseTableManager): +class EventTypeManager(BaseLRUTableManager[EventTypes]): """Manage the EventTypes table.""" def __init__(self, recorder: Recorder) -> None: """Initialize the event type manager.""" - self._id_map: dict[str, int] = LRU(CACHE_SIZE) - self._pending: dict[str, EventTypes] = {} - super().__init__(recorder) + super().__init__(recorder, CACHE_SIZE) def load(self, events: list[Event], session: Session) -> None: """Load the event_type to event_type_ids mapping into memory. 
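The managers in this patch all follow the lifecycle that BaseTableManager defines: rows whose database ids are not yet known are staged in _pending, the session flush assigns the ids, and post_commit_pending() promotes the mappings into _id_map so later lookups skip the database. A minimal sketch of that flow, with a hypothetical DemoManager and DemoRow standing in for the real manager and SQLAlchemy row classes:

from dataclasses import dataclass


@dataclass
class DemoRow:
    """Stand-in for a real row class such as EventData or StatesMeta."""

    shared_data: str
    data_id: int | None = None  # assigned by the database at flush time


class DemoManager:
    """Toy manager mirroring the BaseTableManager pending/id-map lifecycle."""

    def __init__(self) -> None:
        self._pending: dict[str, DemoRow] = {}
        self._id_map: dict[str, int] = {}  # BaseLRUTableManager bounds this with an LRU

    def get_from_cache(self, data: str) -> int | None:
        return self._id_map.get(data)

    def get_pending(self, data: str) -> DemoRow | None:
        return self._pending.get(data)

    def add_pending(self, row: DemoRow) -> None:
        self._pending[row.shared_data] = row

    def post_commit_pending(self) -> None:
        # After a commit every pending row has an id, so promote the
        # mappings into the cache and start the next interval empty.
        for data, row in self._pending.items():
            self._id_map[data] = row.data_id
        self._pending.clear()


manager = DemoManager()
manager.add_pending(DemoRow(shared_data='{"unit": "kWh"}'))
manager.get_pending('{"unit": "kWh"}').data_id = 42  # the session flush does this
manager.post_commit_pending()
assert manager.get_from_cache('{"unit": "kWh"}') == 42

The LRU subclass only swaps the _id_map container for a bounded LRU, so rarely used mappings age out instead of growing without limit.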
@@ -80,14 +77,6 @@ - def get_pending(self, event_type: str) -> EventTypes | None: - """Get pending EventTypes that have not be assigned ids yet. - - This call is not thread-safe and must be called from the - recorder thread. - """ - return self._pending.get(event_type) - def add_pending(self, db_event_type: EventTypes) -> None: """Add a pending EventTypes that will be committed at the next interval. @@ -108,15 +97,6 @@ self._id_map[event_type] = db_event_types.event_type_id self._pending.clear() - def reset(self) -> None: - """Reset the event manager after the database has been reset or changed. - - This call is not thread-safe and must be called from the - recorder thread. - """ - self._id_map.clear() - self._pending.clear() - def evict_purged(self, event_types: Iterable[str]) -> None: """Evict purged event_types from the cache when they are no longer used. diff --git a/homeassistant/components/recorder/table_managers/state_attributes.py b/homeassistant/components/recorder/table_managers/state_attributes.py new file mode 100644 index 00000000000..7489a6f165d --- /dev/null +++ b/homeassistant/components/recorder/table_managers/state_attributes.py @@ -0,0 +1,160 @@ +"""Support managing StateAttributes.""" +from __future__ import annotations + +from collections.abc import Iterable +import logging +from typing import TYPE_CHECKING, cast + +from sqlalchemy.orm.session import Session + +from homeassistant.core import Event +from homeassistant.helpers.entity import entity_sources +from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS + +from . import BaseLRUTableManager +from ..const import SQLITE_MAX_BIND_VARS +from ..db_schema import StateAttributes +from ..queries import get_shared_attributes +from ..util import chunked + +if TYPE_CHECKING: + from ..core import Recorder + +# The number of attribute ids to cache in memory +# +# Based on: +# - The number of overlapping attributes +# - How frequently states with overlapping attributes will change +# - How much memory our low end hardware has +CACHE_SIZE = 2048 + +_LOGGER = logging.getLogger(__name__) + + +class StateAttributesManager(BaseLRUTableManager[StateAttributes]): + """Manage the StateAttributes table.""" + + def __init__( + self, recorder: Recorder, exclude_attributes_by_domain: dict[str, set[str]] + ) -> None: + """Initialize the state attributes manager.""" + super().__init__(recorder, CACHE_SIZE) + self.active = True # always active + self._exclude_attributes_by_domain = exclude_attributes_by_domain + self._entity_sources = entity_sources(recorder.hass) + + def serialize_from_event(self, event: Event) -> bytes | None: + """Serialize event data.""" + try: + return StateAttributes.shared_attrs_bytes_from_event( + event, + self._entity_sources, + self._exclude_attributes_by_domain, + self.recorder.dialect_name, + ) + except JSON_ENCODE_EXCEPTIONS as ex: + _LOGGER.warning( + "State is not JSON serializable: %s: %s", + event.data.get("new_state"), + ex, + ) + return None + + def load(self, events: list[Event], session: Session) -> None: + """Load the shared_attrs to attributes_ids mapping into memory from events. + + This call is not thread-safe and must be called from the + recorder thread. 
+ """ + if hashes := { + StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) + for event in events + if (shared_attrs_bytes := self.serialize_from_event(event)) + }: + self._load_from_hashes(hashes, session) + + def get(self, shared_attr: str, data_hash: int, session: Session) -> int | None: + """Resolve shared_attrs to the attributes_id. + + This call is not thread-safe and must be called from the + recorder thread. + """ + return self.get_many(((shared_attr, data_hash),), session)[shared_attr] + + def get_many( + self, shared_attrs_data_hashes: Iterable[tuple[str, int]], session: Session + ) -> dict[str, int | None]: + """Resolve shared_attrs to attributes_ids. + + This call is not thread-safe and must be called from the + recorder thread. + """ + results: dict[str, int | None] = {} + missing_hashes: set[int] = set() + for shared_attrs, data_hash in shared_attrs_data_hashes: + if (attributes_id := self._id_map.get(shared_attrs)) is None: + missing_hashes.add(data_hash) + + results[shared_attrs] = attributes_id + + if not missing_hashes: + return results + + return results | self._load_from_hashes(missing_hashes, session) + + def _load_from_hashes( + self, hashes: Iterable[int], session: Session + ) -> dict[str, int | None]: + """Load the shared_attrs to attributes_ids mapping into memory from a list of hashes. + + This call is not thread-safe and must be called from the + recorder thread. + """ + results: dict[str, int | None] = {} + with session.no_autoflush: + for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS): + for attributes_id, shared_attrs in session.execute( + get_shared_attributes(hashs_chunk) + ): + results[shared_attrs] = self._id_map[shared_attrs] = cast( + int, attributes_id + ) + + return results + + def add_pending(self, db_state_attributes: StateAttributes) -> None: + """Add a pending StateAttributes that will be committed at the next interval. + + This call is not thread-safe and must be called from the + recorder thread. + """ + assert db_state_attributes.shared_attrs is not None + shared_attrs: str = db_state_attributes.shared_attrs + self._pending[shared_attrs] = db_state_attributes + + def post_commit_pending(self) -> None: + """Call after commit to load the attributes_ids of the new StateAttributes into the LRU. + + This call is not thread-safe and must be called from the + recorder thread. + """ + for shared_attrs, db_state_attributes in self._pending.items(): + self._id_map[shared_attrs] = db_state_attributes.attributes_id + self._pending.clear() + + def evict_purged(self, attributes_ids: set[int]) -> None: + """Evict purged attributes_ids from the cache when they are no longer used. + + This call is not thread-safe and must be called from the + recorder thread. 
+ """ + id_map = self._id_map + state_attributes_ids_reversed = { + attributes_id: shared_attrs + for shared_attrs, attributes_id in id_map.items() + } + # Evict any purged data from the cache + for purged_attributes_id in attributes_ids.intersection( + state_attributes_ids_reversed + ): + id_map.pop(state_attributes_ids_reversed[purged_attributes_id], None) diff --git a/homeassistant/components/recorder/table_managers/states_meta.py b/homeassistant/components/recorder/table_managers/states_meta.py index ded1690df13..b8b763aae33 100644 --- a/homeassistant/components/recorder/table_managers/states_meta.py +++ b/homeassistant/components/recorder/table_managers/states_meta.py @@ -4,12 +4,11 @@ from __future__ import annotations from collections.abc import Iterable from typing import TYPE_CHECKING, cast -from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy.orm.session import Session from homeassistant.core import Event -from . import BaseTableManager +from . import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import StatesMeta from ..queries import find_all_states_metadata_ids, find_states_metadata_ids @@ -21,15 +20,13 @@ if TYPE_CHECKING: CACHE_SIZE = 8192 -class StatesMetaManager(BaseTableManager): +class StatesMetaManager(BaseLRUTableManager[StatesMeta]): """Manage the StatesMeta table.""" def __init__(self, recorder: Recorder) -> None: """Initialize the states meta manager.""" - self._id_map: dict[str, int] = LRU(CACHE_SIZE) - self._pending: dict[str, StatesMeta] = {} self._did_first_load = False - super().__init__(recorder) + super().__init__(recorder, CACHE_SIZE) def load(self, events: list[Event], session: Session) -> None: """Load the entity_id to metadata_id mapping into memory. @@ -112,14 +109,6 @@ class StatesMetaManager(BaseTableManager): return results - def get_pending(self, entity_id: str) -> StatesMeta | None: - """Get pending StatesMeta that have not be assigned ids yet. - - This call is not thread-safe and must be called from the - recorder thread. - """ - return self._pending.get(entity_id) - def add_pending(self, db_states_meta: StatesMeta) -> None: """Add a pending StatesMeta that will be committed at the next interval. @@ -140,15 +129,6 @@ class StatesMetaManager(BaseTableManager): self._id_map[entity_id] = db_states_meta.metadata_id self._pending.clear() - def reset(self) -> None: - """Reset the states meta manager after the database has been reset or changed. - - This call is not thread-safe and must be called from the - recorder thread. - """ - self._id_map.clear() - self._pending.clear() - def evict_purged(self, entity_ids: Iterable[str]) -> None: """Evict purged event_types from the cache when they are no longer used. 
diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 5355931a76a..ed804087d8a 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -1872,9 +1872,11 @@ def test_deduplication_event_data_inside_commit_interval( assert all(event.data_id == first_data_id for event in events) -# Patch STATE_ATTRIBUTES_ID_CACHE_SIZE since otherwise +# Patch CACHE_SIZE since otherwise # the CI can fail because the test takes too long to run -@patch("homeassistant.components.recorder.core.STATE_ATTRIBUTES_ID_CACHE_SIZE", 5) +@patch( + "homeassistant.components.recorder.table_managers.state_attributes.CACHE_SIZE", 5 +) def test_deduplication_state_attributes_inside_commit_interval( hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture ) -> None: @@ -2159,4 +2161,8 @@ async def test_lru_increases_with_many_entities( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10)) await async_wait_recording_done(hass) - assert recorder_mock._state_attributes_ids.get_size() == mock_entity_count * 2 + assert ( + recorder_mock.state_attributes_manager._id_map.get_size() + == mock_entity_count * 2 + ) + assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2
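One detail worth calling out from _load_from_hashes above: the SELECTs are issued in batches because SQLite caps how many parameters a single statement may bind, which is what SQLITE_MAX_BIND_VARS guards (998 in the recorder's const.py at the time of this patch, one under SQLite's default limit of 999). A standalone sketch equivalent to the chunked() helper imported from ..util:

from collections.abc import Generator, Iterable
from itertools import islice
from typing import TypeVar

_T = TypeVar("_T")


def chunked(iterable: Iterable[_T], chunk_size: int) -> Generator[list[_T], None, None]:
    """Yield successive chunk_size-sized lists from iterable."""
    iterator = iter(iterable)
    while chunk := list(islice(iterator, chunk_size)):
        yield chunk


# 2500 hashes resolve in three bounded SELECTs instead of one oversized statement.
hashes = set(range(2500))
assert [len(chunk) for chunk in chunked(hashes, 998)] == [998, 998, 504]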