Migrate StateAttributes to use a table manager (#89760)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
J. Nick Koston 2023-03-15 15:26:29 -10:00 committed by GitHub
parent ccab45520b
commit e379aa23bd
9 changed files with 274 additions and 227 deletions

homeassistant/components/recorder/core.py

@@ -11,10 +11,9 @@ import queue
import sqlite3
import threading
import time
from typing import Any, TypeVar, cast
from typing import Any, TypeVar
import async_timeout
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
@@ -30,7 +29,6 @@ from homeassistant.const import (
MATCH_ALL,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.entity import entity_sources
from homeassistant.helpers.event import (
async_track_time_change,
async_track_time_interval,
@@ -40,7 +38,6 @@ from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
from homeassistant.util.enum import try_parse_enum
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from . import migration, statistics
from .const import (
@@ -52,7 +49,6 @@ from .const import (
MAX_QUEUE_BACKLOG,
MYSQLDB_PYMYSQL_URL_PREFIX,
MYSQLDB_URL_PREFIX,
SQLITE_MAX_BIND_VARS,
SQLITE_URL_PREFIX,
SupportedDialect,
)
@@ -79,8 +75,6 @@ from .models import (
)
from .pool import POOL_SIZE, MutexPool, RecorderPool
from .queries import (
find_shared_attributes_id,
get_shared_attributes,
has_entity_ids_to_migrate,
has_event_type_to_migrate,
has_events_context_ids_to_migrate,
@@ -89,6 +83,7 @@ from .queries import (
from .run_history import RunHistory
from .table_managers.event_data import EventDataManager
from .table_managers.event_types import EventTypeManager
from .table_managers.state_attributes import StateAttributesManager
from .table_managers.states_meta import StatesMetaManager
from .tasks import (
AdjustLRUSizeTask,
@@ -115,7 +110,6 @@ from .tasks import (
)
from .util import (
build_mysqldb_conv,
chunked,
dburl_to_path,
end_incomplete_runs,
is_second_sunday,
@@ -136,15 +130,6 @@ DEFAULT_URL = "sqlite:///{hass_config_path}"
# States and Events objects
EXPIRE_AFTER_COMMITS = 120
# The number of attribute ids to cache in memory
#
# Based on:
# - The number of overlapping attributes
# - How frequently states with overlapping attributes will change
# - How much memory our low end hardware has
STATE_ATTRIBUTES_ID_CACHE_SIZE = 2048
SHUTDOWN_TASK = object()
COMMIT_TASK = CommitTask()
@@ -206,7 +191,6 @@ class Recorder(threading.Thread):
self._queue_watch = threading.Event()
self.engine: Engine | None = None
self.run_history = RunHistory()
self._entity_sources = entity_sources(hass)
# The entity_filter is exposed on the recorder instance so that
# it can be used to see if an entity is being recorded and is called
@@ -217,11 +201,12 @@ class Recorder(threading.Thread):
self.schema_version = 0
self._commits_without_expire = 0
self._old_states: dict[str | None, States] = {}
self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
self.event_data_manager = EventDataManager(self)
self.event_type_manager = EventTypeManager(self)
self.states_meta_manager = StatesMetaManager(self)
self._pending_state_attributes: dict[str, StateAttributes] = {}
self.state_attributes_manager = StateAttributesManager(
self, exclude_attributes_by_domain
)
self._pending_expunge: list[States] = []
self.event_session: Session | None = None
self._get_session: Callable[[], Session] | None = None
@@ -231,7 +216,6 @@ class Recorder(threading.Thread):
self.migration_is_live = False
self._database_lock_task: DatabaseLockTask | None = None
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
self._exclude_attributes_by_domain = exclude_attributes_by_domain
self._event_listener: CALLBACK_TYPE | None = None
self._queue_watcher: CALLBACK_TYPE | None = None
@@ -507,11 +491,9 @@ class Recorder(threading.Thread):
If the number of entities has increased, increase the size of the LRU
cache to avoid thrashing.
"""
state_attributes_lru = self._state_attributes_ids
current_size = state_attributes_lru.get_size()
new_size = self.hass.states.async_entity_ids_count() * 2
if new_size > current_size:
state_attributes_lru.set_size(new_size)
self.state_attributes_manager.adjust_lru_size(new_size)
self.states_meta_manager.adjust_lru_size(new_size)
@callback
def async_periodic_statistics(self) -> None:
@@ -776,33 +758,10 @@ class Recorder(threading.Thread):
non_state_change_events.append(event_)
assert self.event_session is not None
self._pre_process_state_change_events(state_change_events)
self.event_data_manager.load(non_state_change_events, self.event_session)
self.event_type_manager.load(non_state_change_events, self.event_session)
self.states_meta_manager.load(state_change_events, self.event_session)
def _pre_process_state_change_events(self, events: list[Event]) -> None:
"""Load startup state attributes from the database.
Since the _state_attributes_ids cache is empty at startup
we restore it from the database to avoid having to look up
the attributes in the database for every state change
until it is primed.
"""
assert self.event_session is not None
if hashes := {
StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
for event in events
if (
shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
)
}:
with self.event_session.no_autoflush:
for hash_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
for id_, shared_attrs in self.event_session.execute(
get_shared_attributes(hash_chunk)
).fetchall():
self._state_attributes_ids[shared_attrs] = id_
self.state_attributes_manager.load(state_change_events, self.event_session)
def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None:
"""Process a task, guarding against exceptions to ensure the loop does not collapse."""
@@ -932,24 +891,6 @@ class Recorder(threading.Thread):
if not self.commit_interval:
self._commit_event_session_or_retry()
def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None:
"""Find shared attributes in the db from the hash and shared_attrs."""
#
# Avoid the event session being flushed since it will
# commit all the pending events and states to the database.
#
# The lookup has already checked to see if the data is cached
# or going to be written in the next commit so there is no
# need to flush before checking the database.
#
assert self.event_session is not None
with self.event_session.no_autoflush:
if attributes_id := self.event_session.execute(
find_shared_attributes_id(attr_hash, shared_attrs)
).first():
return cast(int, attributes_id[0])
return None
def _process_non_state_changed_event_into_session(self, event: Event) -> None:
"""Process any event into the session except state changed."""
session = self.event_session
@@ -996,67 +937,53 @@ class Recorder(threading.Thread):
session.add(dbevent)
def _serialize_state_attributes_from_event(self, event: Event) -> bytes | None:
"""Serialize state changed event data."""
try:
return StateAttributes.shared_attrs_bytes_from_event(
event,
self._entity_sources,
self._exclude_attributes_by_domain,
self.dialect_name,
)
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning(
"State is not JSON serializable: %s: %s",
event.data.get("new_state"),
ex,
)
return None
def _process_state_changed_event_into_session(self, event: Event) -> None:
"""Process a state_changed event into the session."""
state_attributes_manager = self.state_attributes_manager
dbstate = States.from_event(event)
if (entity_id := dbstate.entity_id) is None or not (
shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
shared_attrs_bytes := state_attributes_manager.serialize_from_event(event)
):
return
assert self.event_session is not None
event_session = self.event_session
session = self.event_session
# Map the entity_id to the StatesMeta table
states_meta_manager = self.states_meta_manager
if pending_states_meta := states_meta_manager.get_pending(entity_id):
dbstate.states_meta_rel = pending_states_meta
elif metadata_id := states_meta_manager.get(entity_id, event_session, True):
elif metadata_id := states_meta_manager.get(entity_id, session, True):
dbstate.metadata_id = metadata_id
else:
states_meta = StatesMeta(entity_id=entity_id)
states_meta_manager.add_pending(states_meta)
event_session.add(states_meta)
session.add(states_meta)
dbstate.states_meta_rel = states_meta
# Map the event data to the StateAttributes table
shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
dbstate.state_attributes = pending_attributes
if pending_event_data := state_attributes_manager.get_pending(shared_attrs):
dbstate.state_attributes = pending_event_data
# Matching attributes id found in the cache
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
elif (
attributes_id := state_attributes_manager.get_from_cache(shared_attrs)
) or (
(hash_ := StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes))
and (
attributes_id := state_attributes_manager.get(
shared_attrs, hash_, session
)
)
):
dbstate.attributes_id = attributes_id
else:
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
# Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id
self._state_attributes_ids[shared_attrs] = attributes_id
# No matching attributes found, save them in the DB
else:
dbstate_attributes = StateAttributes(
shared_attrs=shared_attrs, hash=attr_hash
)
dbstate.state_attributes = dbstate_attributes
self._pending_state_attributes[shared_attrs] = dbstate_attributes
self.event_session.add(dbstate_attributes)
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
state_attributes_manager.add_pending(dbstate_attributes)
session.add(dbstate_attributes)
dbstate.state_attributes = dbstate_attributes
if old_state := self._old_states.pop(entity_id, None):
if old_state.state_id:
@@ -1128,11 +1055,7 @@ class Recorder(threading.Thread):
# and we now know the attributes_ids. We can save
# many selects for matching attributes by loading them
# into the LRU cache now.
for state_attr in self._pending_state_attributes.values():
self._state_attributes_ids[
state_attr.shared_attrs
] = state_attr.attributes_id
self._pending_state_attributes = {}
self.state_attributes_manager.post_commit_pending()
self.event_data_manager.post_commit_pending()
self.event_type_manager.post_commit_pending()
self.states_meta_manager.post_commit_pending()
@@ -1158,8 +1081,7 @@ class Recorder(threading.Thread):
def _close_event_session(self) -> None:
"""Close the event session."""
self._old_states.clear()
self._state_attributes_ids.clear()
self._pending_state_attributes.clear()
self.state_attributes_manager.reset()
self.event_data_manager.reset()
self.event_type_manager.reset()
self.states_meta_manager.reset()
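
Taken together, the core.py changes above reduce attribute handling to a single ordered lookup inside _process_state_changed_event_into_session: pending rows staged in the current commit interval, then the LRU id cache, then a batched database query, and only then a brand-new pending row. Below is a minimal sketch of that order using plain dicts; FakeRow, FakeManager, and their fields are illustrative stand-ins, not the recorder's real API.

from dataclasses import dataclass, field

@dataclass
class FakeRow:
    """Stand-in for a StateAttributes ORM row that has no id yet."""

    shared_attrs: str
    attributes_id: int | None = None

@dataclass
class FakeManager:
    """Stand-in for StateAttributesManager with a dict as the 'database'."""

    pending: dict[str, FakeRow] = field(default_factory=dict)
    id_map: dict[str, int] = field(default_factory=dict)  # an LRU in the real code
    database: dict[str, int] = field(default_factory=dict)

    def resolve(self, shared_attrs: str) -> FakeRow | int:
        # 1. Matching attributes already staged in this commit interval
        if pending := self.pending.get(shared_attrs):
            return pending
        # 2. Matching attributes_id cached from an earlier commit
        if attributes_id := self.id_map.get(shared_attrs):
            return attributes_id
        # 3. Matching row already in the database; cache the id for next time
        if attributes_id := self.database.get(shared_attrs):
            self.id_map[shared_attrs] = attributes_id
            return attributes_id
        # 4. Nothing matched: stage a new row to be inserted at the next commit
        row = self.pending[shared_attrs] = FakeRow(shared_attrs)
        return row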

homeassistant/components/recorder/purge.py

@@ -479,28 +479,6 @@ def _evict_purged_states_from_old_states_cache(
old_states.pop(old_state_reversed[purged_state_id], None)
def _evict_purged_attributes_from_attributes_cache(
instance: Recorder, purged_attributes_ids: set[int]
) -> None:
"""Evict purged attribute ids from the attribute ids cache."""
# Make a map from attributes_id to the attributes json
state_attributes_ids = (
instance._state_attributes_ids # pylint: disable=protected-access
)
state_attributes_ids_reversed = {
attributes_id: attributes
for attributes, attributes_id in state_attributes_ids.items()
}
# Evict any purged attributes from the state_attributes_ids cache
for purged_attribute_id in purged_attributes_ids.intersection(
state_attributes_ids_reversed
):
state_attributes_ids.pop(
state_attributes_ids_reversed[purged_attribute_id], None
)
def _purge_batch_attributes_ids(
instance: Recorder, session: Session, attributes_ids: set[int]
) -> None:
@@ -512,7 +490,7 @@ def _purge_batch_attributes_ids(
_LOGGER.debug("Deleted %s attribute states", deleted_rows)
# Evict any entries in the state_attributes_ids cache referring to a purged state
_evict_purged_attributes_from_attributes_cache(instance, attributes_ids)
instance.state_attributes_manager.evict_purged(attributes_ids)
def _purge_batch_data_ids(

homeassistant/components/recorder/queries.py

@@ -74,17 +74,6 @@ def find_states_metadata_ids(entity_ids: Iterable[str]) -> StatementLambdaElement:
)
def find_shared_attributes_id(
data_hash: int, shared_attrs: str
) -> StatementLambdaElement:
"""Find an attributes_id by hash and shared_attrs."""
return lambda_stmt(
lambda: select(StateAttributes.attributes_id)
.filter(StateAttributes.hash == data_hash)
.filter(StateAttributes.shared_attrs == shared_attrs)
)
def _state_attrs_exist(attr: int | None) -> Select:
"""Check if a state attributes id exists in the states table."""
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
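
The removed find_shared_attributes_id above cost one round trip per attribute hash; the manager now resolves whole chunks at once through get_shared_attributes, which this commit imports but does not show. A plausible reconstruction of that batched statement, assuming it simply selects on hash IN (...) and returns (attributes_id, shared_attrs) pairs in the order the manager unpacks them:

from sqlalchemy import lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

from homeassistant.components.recorder.db_schema import StateAttributes

def get_shared_attributes(hashes: list[int]) -> StatementLambdaElement:
    """Load (attributes_id, shared_attrs) rows for a chunk of hashes."""
    # hash is indexed, so the IN clause narrows to a handful of candidate
    # rows; the caller compares the full shared_attrs string afterwards,
    # which also resolves any hash collisions.
    return lambda_stmt(
        lambda: select(
            StateAttributes.attributes_id, StateAttributes.shared_attrs
        ).where(StateAttributes.hash.in_(hashes))
    )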

homeassistant/components/recorder/table_managers/__init__.py

@@ -1,15 +1,75 @@
"""Managers for each table."""
from typing import TYPE_CHECKING
from collections.abc import MutableMapping
from typing import TYPE_CHECKING, Generic, TypeVar
from lru import LRU # pylint: disable=no-name-in-module
if TYPE_CHECKING:
from ..core import Recorder
_DataT = TypeVar("_DataT")
class BaseTableManager:
class BaseTableManager(Generic[_DataT]):
"""Base class for table managers."""
def __init__(self, recorder: "Recorder") -> None:
"""Initialize the table manager."""
"""Initialize the table manager.
The table manager is responsible for managing the id mappings
for a table. When data is committed to the database, the
manager will move the data from the pending map to the id map.
"""
self.active = False
self.recorder = recorder
self._pending: dict[str, _DataT] = {}
self._id_map: MutableMapping[str, int] = {}
def get_from_cache(self, data: str) -> int | None:
"""Resolve data to the id without accessing the underlying database.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._id_map.get(data)
def get_pending(self, shared_data: str) -> _DataT | None:
"""Get pending data that have not be assigned ids yet.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._pending.get(shared_data)
def reset(self) -> None:
"""Reset after the database has been reset or changed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._id_map.clear()
self._pending.clear()
class BaseLRUTableManager(BaseTableManager[_DataT]):
"""Base class for LRU table managers."""
def __init__(self, recorder: "Recorder", lru_size: int) -> None:
"""Initialize the LRU table manager.
We keep track of the most recently used items
and evict the least recently used items when the cache is full.
"""
super().__init__(recorder)
self._id_map: MutableMapping[str, int] = LRU(lru_size)
def adjust_lru_size(self, new_size: int) -> None:
"""Adjust the LRU cache size.
This call is not thread-safe and must be called from the
recorder thread.
"""
lru: LRU = self._id_map
if new_size > lru.get_size():
lru.set_size(new_size)
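
The pending-then-promote lifecycle these base classes define can be exercised on its own. A toy walkthrough under two assumptions: the lru package (lru-dict) the managers already depend on is installed, and Row is a hypothetical stand-in for an ORM row such as StateAttributes.

from lru import LRU  # lru-dict, the same cache type the managers use

class Row:
    """Hypothetical stand-in for an ORM row such as StateAttributes."""

    def __init__(self, shared: str) -> None:
        self.shared = shared
        self.id: int | None = None

id_map = LRU(4)               # plays the role of BaseLRUTableManager._id_map
pending: dict[str, Row] = {}  # plays the role of BaseTableManager._pending

row = Row('{"unit_of_measurement": "W"}')
pending[row.shared] = row     # add_pending(): staged, no id assigned yet
row.id = 42                   # the database assigns the id at commit time
for shared, db_row in pending.items():  # post_commit_pending()
    id_map[shared] = db_row.id
pending.clear()

assert id_map['{"unit_of_measurement": "W"}'] == 42  # later lookups skip the db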

homeassistant/components/recorder/table_managers/event_data.py

@@ -5,13 +5,12 @@ from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING, cast
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy.orm.session import Session
from homeassistant.core import Event
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from . import BaseTableManager
from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventData
from ..queries import get_shared_event_datas
@@ -26,14 +25,12 @@ CACHE_SIZE = 2048
_LOGGER = logging.getLogger(__name__)
class EventDataManager(BaseTableManager):
class EventDataManager(BaseLRUTableManager[EventData]):
"""Manage the EventData table."""
def __init__(self, recorder: Recorder) -> None:
"""Initialize the event type manager."""
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
self._pending: dict[str, EventData] = {}
super().__init__(recorder)
super().__init__(recorder, CACHE_SIZE)
self.active = True # always active
def serialize_from_event(self, event: Event) -> bytes | None:
@@ -67,14 +64,6 @@
"""
return self.get_many(((shared_data, data_hash),), session)[shared_data]
def get_from_cache(self, shared_data: str) -> int | None:
"""Resolve shared_data to the data_id without accessing the underlying database.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._id_map.get(shared_data)
def get_many(
self, shared_data_data_hashs: Iterable[tuple[str, int]], session: Session
) -> dict[str, int | None]:
@@ -116,14 +105,6 @@
return results
def get_pending(self, shared_data: str) -> EventData | None:
"""Get pending EventData that have not be assigned ids yet.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._pending.get(shared_data)
def add_pending(self, db_event_data: EventData) -> None:
"""Add a pending EventData that will be committed at the next interval.
@@ -144,15 +125,6 @@
self._id_map[shared_data] = db_event_data.data_id
self._pending.clear()
def reset(self) -> None:
"""Reset the event manager after the database has been reset or changed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._id_map.clear()
self._pending.clear()
def evict_purged(self, data_ids: set[int]) -> None:
"""Evict purged data_ids from the cache when they are no longer used.

homeassistant/components/recorder/table_managers/event_types.py

@@ -4,12 +4,11 @@ from __future__ import annotations
from collections.abc import Iterable
from typing import TYPE_CHECKING, cast
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy.orm.session import Session
from homeassistant.core import Event
from . import BaseTableManager
from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventTypes
from ..queries import find_event_type_ids
@@ -22,14 +21,12 @@ if TYPE_CHECKING:
CACHE_SIZE = 2048
class EventTypeManager(BaseTableManager):
class EventTypeManager(BaseLRUTableManager[EventTypes]):
"""Manage the EventTypes table."""
def __init__(self, recorder: Recorder) -> None:
"""Initialize the event type manager."""
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
self._pending: dict[str, EventTypes] = {}
super().__init__(recorder)
super().__init__(recorder, CACHE_SIZE)
def load(self, events: list[Event], session: Session) -> None:
"""Load the event_type to event_type_ids mapping into memory.
@@ -80,14 +77,6 @@
return results
def get_pending(self, event_type: str) -> EventTypes | None:
"""Get pending EventTypes that have not be assigned ids yet.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._pending.get(event_type)
def add_pending(self, db_event_type: EventTypes) -> None:
"""Add a pending EventTypes that will be committed at the next interval.
@@ -108,15 +97,6 @@
self._id_map[event_type] = db_event_types.event_type_id
self._pending.clear()
def reset(self) -> None:
"""Reset the event manager after the database has been reset or changed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._id_map.clear()
self._pending.clear()
def evict_purged(self, event_types: Iterable[str]) -> None:
"""Evict purged event_types from the cache when they are no longer used.

homeassistant/components/recorder/table_managers/state_attributes.py (new file)

@@ -0,0 +1,160 @@
"""Support managing StateAttributes."""
from __future__ import annotations
from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session
from homeassistant.core import Event
from homeassistant.helpers.entity import entity_sources
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StateAttributes
from ..queries import get_shared_attributes
from ..util import chunked
if TYPE_CHECKING:
from ..core import Recorder
# The number of attribute ids to cache in memory
#
# Based on:
# - The number of overlapping attributes
# - How frequently states with overlapping attributes will change
# - How much memory our low end hardware has
CACHE_SIZE = 2048
_LOGGER = logging.getLogger(__name__)
class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
"""Manage the StateAttributes table."""
def __init__(
self, recorder: Recorder, exclude_attributes_by_domain: dict[str, set[str]]
) -> None:
"""Initialize the event type manager."""
super().__init__(recorder, CACHE_SIZE)
self.active = True # always active
self._exclude_attributes_by_domain = exclude_attributes_by_domain
self._entity_sources = entity_sources(recorder.hass)
def serialize_from_event(self, event: Event) -> bytes | None:
"""Serialize event data."""
try:
return StateAttributes.shared_attrs_bytes_from_event(
event,
self._entity_sources,
self._exclude_attributes_by_domain,
self.recorder.dialect_name,
)
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning(
"State is not JSON serializable: %s: %s",
event.data.get("new_state"),
ex,
)
return None
def load(self, events: list[Event], session: Session) -> None:
"""Load the shared_attrs to attributes_ids mapping into memory from events.
This call is not thread-safe and must be called from the
recorder thread.
"""
if hashes := {
StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
for event in events
if (shared_attrs_bytes := self.serialize_from_event(event))
}:
self._load_from_hashes(hashes, session)
def get(self, shared_attr: str, data_hash: int, session: Session) -> int | None:
"""Resolve shared_attrs to the attributes_id.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self.get_many(((shared_attr, data_hash),), session)[shared_attr]
def get_many(
self, shared_attrs_data_hashes: Iterable[tuple[str, int]], session: Session
) -> dict[str, int | None]:
"""Resolve shared_attrs to attributes_ids.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
missing_hashes: set[int] = set()
for shared_attrs, data_hash in shared_attrs_data_hashes:
if (attributes_id := self._id_map.get(shared_attrs)) is None:
missing_hashes.add(data_hash)
results[shared_attrs] = attributes_id
if not missing_hashes:
return results
return results | self._load_from_hashes(missing_hashes, session)
def _load_from_hashes(
self, hashes: Iterable[int], session: Session
) -> dict[str, int | None]:
"""Load the shared_attrs to attributes_ids mapping into memory from a list of hashes.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
with session.no_autoflush:
for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
for attributes_id, shared_attrs in session.execute(
get_shared_attributes(hashs_chunk)
):
results[shared_attrs] = self._id_map[shared_attrs] = cast(
int, attributes_id
)
return results
def add_pending(self, db_state_attributes: StateAttributes) -> None:
"""Add a pending StateAttributes that will be committed at the next interval.
This call is not thread-safe and must be called from the
recorder thread.
"""
assert db_state_attributes.shared_attrs is not None
shared_attrs: str = db_state_attributes.shared_attrs
self._pending[shared_attrs] = db_state_attributes
def post_commit_pending(self) -> None:
"""Call after commit to load the attributes_ids of the new StateAttributes into the LRU.
This call is not thread-safe and must be called from the
recorder thread.
"""
for shared_attrs, db_state_attributes in self._pending.items():
self._id_map[shared_attrs] = db_state_attributes.attributes_id
self._pending.clear()
def evict_purged(self, attributes_ids: set[int]) -> None:
"""Evict purged attributes_ids from the cache when they are no longer used.
This call is not thread-safe and must be called from the
recorder thread.
"""
id_map = self._id_map
state_attributes_ids_reversed = {
attributes_id: shared_attrs
for shared_attrs, attributes_id in id_map.items()
}
# Evict any purged data from the cache
for purged_attributes_id in attributes_ids.intersection(
state_attributes_ids_reversed
):
id_map.pop(state_attributes_ids_reversed[purged_attributes_id], None)
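
A short worked example of the eviction above: purge only knows the deleted ids, while the cache is keyed by shared_attrs strings, so the id map is inverted once per batch instead of being scanned once per id. A plain dict stands in for the LRU here.

# Keys are shared_attrs strings, values are attributes_ids.
id_map = {'{"a": 1}': 10, '{"b": 2}': 11, '{"c": 3}': 12}
purged_ids = {11, 12}

# Invert once, then drop only the keys whose ids were purged.
reversed_map = {attributes_id: shared for shared, attributes_id in id_map.items()}
for purged_id in purged_ids.intersection(reversed_map):
    id_map.pop(reversed_map[purged_id], None)

assert id_map == {'{"a": 1}': 10}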

homeassistant/components/recorder/table_managers/states_meta.py

@@ -4,12 +4,11 @@ from __future__ import annotations
from collections.abc import Iterable
from typing import TYPE_CHECKING, cast
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy.orm.session import Session
from homeassistant.core import Event
from . import BaseTableManager
from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StatesMeta
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
@@ -21,15 +20,13 @@ if TYPE_CHECKING:
CACHE_SIZE = 8192
class StatesMetaManager(BaseTableManager):
class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
"""Manage the StatesMeta table."""
def __init__(self, recorder: Recorder) -> None:
"""Initialize the states meta manager."""
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
self._pending: dict[str, StatesMeta] = {}
self._did_first_load = False
super().__init__(recorder)
super().__init__(recorder, CACHE_SIZE)
def load(self, events: list[Event], session: Session) -> None:
"""Load the entity_id to metadata_id mapping into memory.
@@ -112,14 +109,6 @@
return results
def get_pending(self, entity_id: str) -> StatesMeta | None:
"""Get pending StatesMeta that have not be assigned ids yet.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._pending.get(entity_id)
def add_pending(self, db_states_meta: StatesMeta) -> None:
"""Add a pending StatesMeta that will be committed at the next interval.
@@ -140,15 +129,6 @@
self._id_map[entity_id] = db_states_meta.metadata_id
self._pending.clear()
def reset(self) -> None:
"""Reset the states meta manager after the database has been reset or changed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._id_map.clear()
self._pending.clear()
def evict_purged(self, entity_ids: Iterable[str]) -> None:
"""Evict purged event_types from the cache when they are no longer used.

tests/components/recorder/test_init.py

@@ -1872,9 +1872,11 @@
assert all(event.data_id == first_data_id for event in events)
# Patch STATE_ATTRIBUTES_ID_CACHE_SIZE since otherwise
# Patch CACHE_SIZE since otherwise
# the CI can fail because the test takes too long to run
@patch("homeassistant.components.recorder.core.STATE_ATTRIBUTES_ID_CACHE_SIZE", 5)
@patch(
"homeassistant.components.recorder.table_managers.state_attributes.CACHE_SIZE", 5
)
def test_deduplication_state_attributes_inside_commit_interval(
hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
) -> None:
@@ -2159,4 +2161,8 @@
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10))
await async_wait_recording_done(hass)
assert recorder_mock._state_attributes_ids.get_size() == mock_entity_count * 2
assert (
recorder_mock.state_attributes_manager._id_map.get_size()
== mock_entity_count * 2
)
assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2