Migrate StateAttributes to use a table manager (#89760)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
ccab45520b
commit
e379aa23bd
9 changed files with 274 additions and 227 deletions
|
@ -11,10 +11,9 @@ import queue
|
|||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import async_timeout
|
||||
from lru import LRU # pylint: disable=no-name-in-module
|
||||
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
@ -30,7 +29,6 @@ from homeassistant.const import (
|
|||
MATCH_ALL,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity import entity_sources
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_time_change,
|
||||
async_track_time_interval,
|
||||
|
@ -40,7 +38,6 @@ from homeassistant.helpers.start import async_at_started
|
|||
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.enum import try_parse_enum
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
|
||||
|
||||
from . import migration, statistics
|
||||
from .const import (
|
||||
|
@ -52,7 +49,6 @@ from .const import (
|
|||
MAX_QUEUE_BACKLOG,
|
||||
MYSQLDB_PYMYSQL_URL_PREFIX,
|
||||
MYSQLDB_URL_PREFIX,
|
||||
SQLITE_MAX_BIND_VARS,
|
||||
SQLITE_URL_PREFIX,
|
||||
SupportedDialect,
|
||||
)
|
||||
|
@ -79,8 +75,6 @@ from .models import (
|
|||
)
|
||||
from .pool import POOL_SIZE, MutexPool, RecorderPool
|
||||
from .queries import (
|
||||
find_shared_attributes_id,
|
||||
get_shared_attributes,
|
||||
has_entity_ids_to_migrate,
|
||||
has_event_type_to_migrate,
|
||||
has_events_context_ids_to_migrate,
|
||||
|
@ -89,6 +83,7 @@ from .queries import (
|
|||
from .run_history import RunHistory
|
||||
from .table_managers.event_data import EventDataManager
|
||||
from .table_managers.event_types import EventTypeManager
|
||||
from .table_managers.state_attributes import StateAttributesManager
|
||||
from .table_managers.states_meta import StatesMetaManager
|
||||
from .tasks import (
|
||||
AdjustLRUSizeTask,
|
||||
|
@ -115,7 +110,6 @@ from .tasks import (
|
|||
)
|
||||
from .util import (
|
||||
build_mysqldb_conv,
|
||||
chunked,
|
||||
dburl_to_path,
|
||||
end_incomplete_runs,
|
||||
is_second_sunday,
|
||||
|
@ -136,15 +130,6 @@ DEFAULT_URL = "sqlite:///{hass_config_path}"
|
|||
# States and Events objects
|
||||
EXPIRE_AFTER_COMMITS = 120
|
||||
|
||||
# The number of attribute ids to cache in memory
|
||||
#
|
||||
# Based on:
|
||||
# - The number of overlapping attributes
|
||||
# - How frequently states with overlapping attributes will change
|
||||
# - How much memory our low end hardware has
|
||||
STATE_ATTRIBUTES_ID_CACHE_SIZE = 2048
|
||||
|
||||
|
||||
SHUTDOWN_TASK = object()
|
||||
|
||||
COMMIT_TASK = CommitTask()
|
||||
|
@ -206,7 +191,6 @@ class Recorder(threading.Thread):
|
|||
self._queue_watch = threading.Event()
|
||||
self.engine: Engine | None = None
|
||||
self.run_history = RunHistory()
|
||||
self._entity_sources = entity_sources(hass)
|
||||
|
||||
# The entity_filter is exposed on the recorder instance so that
|
||||
# it can be used to see if an entity is being recorded and is called
|
||||
|
@ -217,11 +201,12 @@ class Recorder(threading.Thread):
|
|||
self.schema_version = 0
|
||||
self._commits_without_expire = 0
|
||||
self._old_states: dict[str | None, States] = {}
|
||||
self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
|
||||
self.event_data_manager = EventDataManager(self)
|
||||
self.event_type_manager = EventTypeManager(self)
|
||||
self.states_meta_manager = StatesMetaManager(self)
|
||||
self._pending_state_attributes: dict[str, StateAttributes] = {}
|
||||
self.state_attributes_manager = StateAttributesManager(
|
||||
self, exclude_attributes_by_domain
|
||||
)
|
||||
self._pending_expunge: list[States] = []
|
||||
self.event_session: Session | None = None
|
||||
self._get_session: Callable[[], Session] | None = None
|
||||
|
@ -231,7 +216,6 @@ class Recorder(threading.Thread):
|
|||
self.migration_is_live = False
|
||||
self._database_lock_task: DatabaseLockTask | None = None
|
||||
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
|
||||
self._exclude_attributes_by_domain = exclude_attributes_by_domain
|
||||
|
||||
self._event_listener: CALLBACK_TYPE | None = None
|
||||
self._queue_watcher: CALLBACK_TYPE | None = None
|
||||
|
@ -507,11 +491,9 @@ class Recorder(threading.Thread):
|
|||
If the number of entities has increased, increase the size of the LRU
|
||||
cache to avoid thrashing.
|
||||
"""
|
||||
state_attributes_lru = self._state_attributes_ids
|
||||
current_size = state_attributes_lru.get_size()
|
||||
new_size = self.hass.states.async_entity_ids_count() * 2
|
||||
if new_size > current_size:
|
||||
state_attributes_lru.set_size(new_size)
|
||||
self.state_attributes_manager.adjust_lru_size(new_size)
|
||||
self.states_meta_manager.adjust_lru_size(new_size)
|
||||
|
||||
@callback
|
||||
def async_periodic_statistics(self) -> None:
|
||||
|
@ -776,33 +758,10 @@ class Recorder(threading.Thread):
|
|||
non_state_change_events.append(event_)
|
||||
|
||||
assert self.event_session is not None
|
||||
self._pre_process_state_change_events(state_change_events)
|
||||
self.event_data_manager.load(non_state_change_events, self.event_session)
|
||||
self.event_type_manager.load(non_state_change_events, self.event_session)
|
||||
self.states_meta_manager.load(state_change_events, self.event_session)
|
||||
|
||||
def _pre_process_state_change_events(self, events: list[Event]) -> None:
|
||||
"""Load startup state attributes from the database.
|
||||
|
||||
Since the _state_attributes_ids cache is empty at startup
|
||||
we restore it from the database to avoid having to look up
|
||||
the attributes in the database for every state change
|
||||
until its primed.
|
||||
"""
|
||||
assert self.event_session is not None
|
||||
if hashes := {
|
||||
StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
|
||||
for event in events
|
||||
if (
|
||||
shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
|
||||
)
|
||||
}:
|
||||
with self.event_session.no_autoflush:
|
||||
for hash_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
|
||||
for id_, shared_attrs in self.event_session.execute(
|
||||
get_shared_attributes(hash_chunk)
|
||||
).fetchall():
|
||||
self._state_attributes_ids[shared_attrs] = id_
|
||||
self.state_attributes_manager.load(state_change_events, self.event_session)
|
||||
|
||||
def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None:
|
||||
"""Process a task, guarding against exceptions to ensure the loop does not collapse."""
|
||||
|
@ -932,24 +891,6 @@ class Recorder(threading.Thread):
|
|||
if not self.commit_interval:
|
||||
self._commit_event_session_or_retry()
|
||||
|
||||
def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None:
|
||||
"""Find shared attributes in the db from the hash and shared_attrs."""
|
||||
#
|
||||
# Avoid the event session being flushed since it will
|
||||
# commit all the pending events and states to the database.
|
||||
#
|
||||
# The lookup has already have checked to see if the data is cached
|
||||
# or going to be written in the next commit so there is no
|
||||
# need to flush before checking the database.
|
||||
#
|
||||
assert self.event_session is not None
|
||||
with self.event_session.no_autoflush:
|
||||
if attributes_id := self.event_session.execute(
|
||||
find_shared_attributes_id(attr_hash, shared_attrs)
|
||||
).first():
|
||||
return cast(int, attributes_id[0])
|
||||
return None
|
||||
|
||||
def _process_non_state_changed_event_into_session(self, event: Event) -> None:
|
||||
"""Process any event into the session except state changed."""
|
||||
session = self.event_session
|
||||
|
@ -996,67 +937,53 @@ class Recorder(threading.Thread):
|
|||
|
||||
session.add(dbevent)
|
||||
|
||||
def _serialize_state_attributes_from_event(self, event: Event) -> bytes | None:
|
||||
"""Serialize state changed event data."""
|
||||
try:
|
||||
return StateAttributes.shared_attrs_bytes_from_event(
|
||||
event,
|
||||
self._entity_sources,
|
||||
self._exclude_attributes_by_domain,
|
||||
self.dialect_name,
|
||||
)
|
||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||
_LOGGER.warning(
|
||||
"State is not JSON serializable: %s: %s",
|
||||
event.data.get("new_state"),
|
||||
ex,
|
||||
)
|
||||
return None
|
||||
|
||||
def _process_state_changed_event_into_session(self, event: Event) -> None:
|
||||
"""Process a state_changed event into the session."""
|
||||
state_attributes_manager = self.state_attributes_manager
|
||||
dbstate = States.from_event(event)
|
||||
if (entity_id := dbstate.entity_id) is None or not (
|
||||
shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
|
||||
shared_attrs_bytes := state_attributes_manager.serialize_from_event(event)
|
||||
):
|
||||
return
|
||||
|
||||
assert self.event_session is not None
|
||||
event_session = self.event_session
|
||||
session = self.event_session
|
||||
# Map the entity_id to the StatesMeta table
|
||||
states_meta_manager = self.states_meta_manager
|
||||
if pending_states_meta := states_meta_manager.get_pending(entity_id):
|
||||
dbstate.states_meta_rel = pending_states_meta
|
||||
elif metadata_id := states_meta_manager.get(entity_id, event_session, True):
|
||||
elif metadata_id := states_meta_manager.get(entity_id, session, True):
|
||||
dbstate.metadata_id = metadata_id
|
||||
else:
|
||||
states_meta = StatesMeta(entity_id=entity_id)
|
||||
states_meta_manager.add_pending(states_meta)
|
||||
event_session.add(states_meta)
|
||||
session.add(states_meta)
|
||||
dbstate.states_meta_rel = states_meta
|
||||
|
||||
# Map the event data to the StateAttributes table
|
||||
shared_attrs = shared_attrs_bytes.decode("utf-8")
|
||||
dbstate.attributes = None
|
||||
# Matching attributes found in the pending commit
|
||||
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
|
||||
dbstate.state_attributes = pending_attributes
|
||||
if pending_event_data := state_attributes_manager.get_pending(shared_attrs):
|
||||
dbstate.state_attributes = pending_event_data
|
||||
# Matching attributes id found in the cache
|
||||
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
|
||||
elif (
|
||||
attributes_id := state_attributes_manager.get_from_cache(shared_attrs)
|
||||
) or (
|
||||
(hash_ := StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes))
|
||||
and (
|
||||
attributes_id := state_attributes_manager.get(
|
||||
shared_attrs, hash_, session
|
||||
)
|
||||
)
|
||||
):
|
||||
dbstate.attributes_id = attributes_id
|
||||
else:
|
||||
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
|
||||
# Matching attributes found in the database
|
||||
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
|
||||
dbstate.attributes_id = attributes_id
|
||||
self._state_attributes_ids[shared_attrs] = attributes_id
|
||||
# No matching attributes found, save them in the DB
|
||||
else:
|
||||
dbstate_attributes = StateAttributes(
|
||||
shared_attrs=shared_attrs, hash=attr_hash
|
||||
)
|
||||
dbstate.state_attributes = dbstate_attributes
|
||||
self._pending_state_attributes[shared_attrs] = dbstate_attributes
|
||||
self.event_session.add(dbstate_attributes)
|
||||
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
|
||||
state_attributes_manager.add_pending(dbstate_attributes)
|
||||
session.add(dbstate_attributes)
|
||||
dbstate.state_attributes = dbstate_attributes
|
||||
|
||||
if old_state := self._old_states.pop(entity_id, None):
|
||||
if old_state.state_id:
|
||||
|
@ -1128,11 +1055,7 @@ class Recorder(threading.Thread):
|
|||
# and we now know the attributes_ids. We can save
|
||||
# many selects for matching attributes by loading them
|
||||
# into the LRU cache now.
|
||||
for state_attr in self._pending_state_attributes.values():
|
||||
self._state_attributes_ids[
|
||||
state_attr.shared_attrs
|
||||
] = state_attr.attributes_id
|
||||
self._pending_state_attributes = {}
|
||||
self.state_attributes_manager.post_commit_pending()
|
||||
self.event_data_manager.post_commit_pending()
|
||||
self.event_type_manager.post_commit_pending()
|
||||
self.states_meta_manager.post_commit_pending()
|
||||
|
@ -1158,8 +1081,7 @@ class Recorder(threading.Thread):
|
|||
def _close_event_session(self) -> None:
|
||||
"""Close the event session."""
|
||||
self._old_states.clear()
|
||||
self._state_attributes_ids.clear()
|
||||
self._pending_state_attributes.clear()
|
||||
self.state_attributes_manager.reset()
|
||||
self.event_data_manager.reset()
|
||||
self.event_type_manager.reset()
|
||||
self.states_meta_manager.reset()
|
||||
|
|
|
@ -479,28 +479,6 @@ def _evict_purged_states_from_old_states_cache(
|
|||
old_states.pop(old_state_reversed[purged_state_id], None)
|
||||
|
||||
|
||||
def _evict_purged_attributes_from_attributes_cache(
|
||||
instance: Recorder, purged_attributes_ids: set[int]
|
||||
) -> None:
|
||||
"""Evict purged attribute ids from the attribute ids cache."""
|
||||
# Make a map from attributes_id to the attributes json
|
||||
state_attributes_ids = (
|
||||
instance._state_attributes_ids # pylint: disable=protected-access
|
||||
)
|
||||
state_attributes_ids_reversed = {
|
||||
attributes_id: attributes
|
||||
for attributes, attributes_id in state_attributes_ids.items()
|
||||
}
|
||||
|
||||
# Evict any purged attributes from the state_attributes_ids cache
|
||||
for purged_attribute_id in purged_attributes_ids.intersection(
|
||||
state_attributes_ids_reversed
|
||||
):
|
||||
state_attributes_ids.pop(
|
||||
state_attributes_ids_reversed[purged_attribute_id], None
|
||||
)
|
||||
|
||||
|
||||
def _purge_batch_attributes_ids(
|
||||
instance: Recorder, session: Session, attributes_ids: set[int]
|
||||
) -> None:
|
||||
|
@ -512,7 +490,7 @@ def _purge_batch_attributes_ids(
|
|||
_LOGGER.debug("Deleted %s attribute states", deleted_rows)
|
||||
|
||||
# Evict any entries in the state attributes manager cache that refer to purged attributes ids
|
||||
_evict_purged_attributes_from_attributes_cache(instance, attributes_ids)
|
||||
instance.state_attributes_manager.evict_purged(attributes_ids)
|
||||
|
||||
|
||||
def _purge_batch_data_ids(
|
||||
|
|
|
@ -74,17 +74,6 @@ def find_states_metadata_ids(entity_ids: Iterable[str]) -> StatementLambdaElemen
|
|||
)
|
||||
|
||||
|
||||
def find_shared_attributes_id(
|
||||
data_hash: int, shared_attrs: str
|
||||
) -> StatementLambdaElement:
|
||||
"""Find an attributes_id by hash and shared_attrs."""
|
||||
return lambda_stmt(
|
||||
lambda: select(StateAttributes.attributes_id)
|
||||
.filter(StateAttributes.hash == data_hash)
|
||||
.filter(StateAttributes.shared_attrs == shared_attrs)
|
||||
)
|
||||
|
||||
|
||||
def _state_attrs_exist(attr: int | None) -> Select:
|
||||
"""Check if a state attributes id exists in the states table."""
|
||||
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
|
||||
|
|
|
@ -1,15 +1,75 @@
|
|||
"""Managers for each table."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import MutableMapping
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from lru import LRU # pylint: disable=no-name-in-module
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core import Recorder
|
||||
|
||||
_DataT = TypeVar("_DataT")
|
||||
|
||||
class BaseTableManager:
|
||||
|
||||
class BaseTableManager(Generic[_DataT]):
|
||||
"""Base class for table managers."""
|
||||
|
||||
def __init__(self, recorder: "Recorder") -> None:
|
||||
"""Initialize the table manager."""
|
||||
"""Initialize the table manager.
|
||||
|
||||
The table manager is responsible for managing the id mappings
|
||||
for a table. When data is committed to the database, the
|
||||
manager will move the data from the pending to the id map.
|
||||
"""
|
||||
self.active = False
|
||||
self.recorder = recorder
|
||||
self._pending: dict[str, _DataT] = {}
|
||||
self._id_map: MutableMapping[str, int] = {}
|
||||
|
||||
def get_from_cache(self, data: str) -> int | None:
|
||||
"""Resolve data to the id without accessing the underlying database.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._id_map.get(data)
|
||||
|
||||
def get_pending(self, shared_data: str) -> _DataT | None:
|
||||
"""Get pending data that has not been assigned ids yet.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._pending.get(shared_data)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset after the database has been reset or changed.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
self._id_map.clear()
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
class BaseLRUTableManager(BaseTableManager[_DataT]):
|
||||
"""Base class for LRU table managers."""
|
||||
|
||||
def __init__(self, recorder: "Recorder", lru_size: int) -> None:
|
||||
"""Initialize the LRU table manager.
|
||||
|
||||
We keep track of the most recently used items
|
||||
and evict the least recently used items when the cache is full.
|
||||
"""
|
||||
super().__init__(recorder)
|
||||
self._id_map: MutableMapping[str, int] = LRU(lru_size)
|
||||
|
||||
def adjust_lru_size(self, new_size: int) -> None:
|
||||
"""Adjust the LRU cache size.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
lru: LRU = self._id_map
|
||||
if new_size > lru.get_size():
|
||||
lru.set_size(new_size)
|
||||
|
|
|
@ -5,13 +5,12 @@ from collections.abc import Iterable
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from lru import LRU # pylint: disable=no-name-in-module
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.core import Event
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
|
||||
|
||||
from . import BaseTableManager
|
||||
from . import BaseLRUTableManager
|
||||
from ..const import SQLITE_MAX_BIND_VARS
|
||||
from ..db_schema import EventData
|
||||
from ..queries import get_shared_event_datas
|
||||
|
@ -26,14 +25,12 @@ CACHE_SIZE = 2048
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventDataManager(BaseTableManager):
|
||||
class EventDataManager(BaseLRUTableManager[EventData]):
|
||||
"""Manage the EventData table."""
|
||||
|
||||
def __init__(self, recorder: Recorder) -> None:
|
||||
"""Initialize the event data manager."""
|
||||
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
|
||||
self._pending: dict[str, EventData] = {}
|
||||
super().__init__(recorder)
|
||||
super().__init__(recorder, CACHE_SIZE)
|
||||
self.active = True # always active
|
||||
|
||||
def serialize_from_event(self, event: Event) -> bytes | None:
|
||||
|
@ -67,14 +64,6 @@ class EventDataManager(BaseTableManager):
|
|||
"""
|
||||
return self.get_many(((shared_data, data_hash),), session)[shared_data]
|
||||
|
||||
def get_from_cache(self, shared_data: str) -> int | None:
|
||||
"""Resolve shared_data to the data_id without accessing the underlying database.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._id_map.get(shared_data)
|
||||
|
||||
def get_many(
|
||||
self, shared_data_data_hashs: Iterable[tuple[str, int]], session: Session
|
||||
) -> dict[str, int | None]:
|
||||
|
@ -116,14 +105,6 @@ class EventDataManager(BaseTableManager):
|
|||
|
||||
return results
|
||||
|
||||
def get_pending(self, shared_data: str) -> EventData | None:
|
||||
"""Get pending EventData that have not be assigned ids yet.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._pending.get(shared_data)
|
||||
|
||||
def add_pending(self, db_event_data: EventData) -> None:
|
||||
"""Add a pending EventData that will be committed at the next interval.
|
||||
|
||||
|
@ -144,15 +125,6 @@ class EventDataManager(BaseTableManager):
|
|||
self._id_map[shared_data] = db_event_data.data_id
|
||||
self._pending.clear()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the event manager after the database has been reset or changed.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
self._id_map.clear()
|
||||
self._pending.clear()
|
||||
|
||||
def evict_purged(self, data_ids: set[int]) -> None:
|
||||
"""Evict purged data_ids from the cache when they are no longer used.
|
||||
|
||||
|
|
|
@ -4,12 +4,11 @@ from __future__ import annotations
|
|||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from lru import LRU # pylint: disable=no-name-in-module
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.core import Event
|
||||
|
||||
from . import BaseTableManager
|
||||
from . import BaseLRUTableManager
|
||||
from ..const import SQLITE_MAX_BIND_VARS
|
||||
from ..db_schema import EventTypes
|
||||
from ..queries import find_event_type_ids
|
||||
|
@ -22,14 +21,12 @@ if TYPE_CHECKING:
|
|||
CACHE_SIZE = 2048
|
||||
|
||||
|
||||
class EventTypeManager(BaseTableManager):
|
||||
class EventTypeManager(BaseLRUTableManager[EventTypes]):
|
||||
"""Manage the EventTypes table."""
|
||||
|
||||
def __init__(self, recorder: Recorder) -> None:
|
||||
"""Initialize the event type manager."""
|
||||
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
|
||||
self._pending: dict[str, EventTypes] = {}
|
||||
super().__init__(recorder)
|
||||
super().__init__(recorder, CACHE_SIZE)
|
||||
|
||||
def load(self, events: list[Event], session: Session) -> None:
|
||||
"""Load the event_type to event_type_ids mapping into memory.
|
||||
|
@ -80,14 +77,6 @@ class EventTypeManager(BaseTableManager):
|
|||
|
||||
return results
|
||||
|
||||
def get_pending(self, event_type: str) -> EventTypes | None:
|
||||
"""Get pending EventTypes that have not be assigned ids yet.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._pending.get(event_type)
|
||||
|
||||
def add_pending(self, db_event_type: EventTypes) -> None:
|
||||
"""Add a pending EventTypes that will be committed at the next interval.
|
||||
|
||||
|
@ -108,15 +97,6 @@ class EventTypeManager(BaseTableManager):
|
|||
self._id_map[event_type] = db_event_types.event_type_id
|
||||
self._pending.clear()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the event manager after the database has been reset or changed.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
self._id_map.clear()
|
||||
self._pending.clear()
|
||||
|
||||
def evict_purged(self, event_types: Iterable[str]) -> None:
|
||||
"""Evict purged event_types from the cache when they are no longer used.
|
||||
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
"""Support managing StateAttributes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.core import Event
|
||||
from homeassistant.helpers.entity import entity_sources
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
|
||||
|
||||
from . import BaseLRUTableManager
|
||||
from ..const import SQLITE_MAX_BIND_VARS
|
||||
from ..db_schema import StateAttributes
|
||||
from ..queries import get_shared_attributes
|
||||
from ..util import chunked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core import Recorder
|
||||
|
||||
# The number of attribute ids to cache in memory
|
||||
#
|
||||
# Based on:
|
||||
# - The number of overlapping attributes
|
||||
# - How frequently states with overlapping attributes will change
|
||||
# - How much memory our low end hardware has
|
||||
CACHE_SIZE = 2048
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
|
||||
"""Manage the StateAttributes table."""
|
||||
|
||||
def __init__(
|
||||
self, recorder: Recorder, exclude_attributes_by_domain: dict[str, set[str]]
|
||||
) -> None:
|
||||
"""Initialize the state attributes manager."""
|
||||
super().__init__(recorder, CACHE_SIZE)
|
||||
self.active = True # always active
|
||||
self._exclude_attributes_by_domain = exclude_attributes_by_domain
|
||||
self._entity_sources = entity_sources(recorder.hass)
|
||||
|
||||
def serialize_from_event(self, event: Event) -> bytes | None:
|
||||
"""Serialize the state attributes from a state changed event."""
|
||||
try:
|
||||
return StateAttributes.shared_attrs_bytes_from_event(
|
||||
event,
|
||||
self._entity_sources,
|
||||
self._exclude_attributes_by_domain,
|
||||
self.recorder.dialect_name,
|
||||
)
|
||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||
_LOGGER.warning(
|
||||
"State is not JSON serializable: %s: %s",
|
||||
event.data.get("new_state"),
|
||||
ex,
|
||||
)
|
||||
return None
|
||||
|
||||
def load(self, events: list[Event], session: Session) -> None:
|
||||
"""Load the shared_attrs to attributes_ids mapping into memory from events.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
if hashes := {
|
||||
StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
|
||||
for event in events
|
||||
if (shared_attrs_bytes := self.serialize_from_event(event))
|
||||
}:
|
||||
self._load_from_hashes(hashes, session)
|
||||
|
||||
def get(self, shared_attr: str, data_hash: int, session: Session) -> int | None:
|
||||
"""Resolve shared_attrs to the attributes_id.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self.get_many(((shared_attr, data_hash),), session)[shared_attr]
|
||||
|
||||
def get_many(
|
||||
self, shared_attrs_data_hashes: Iterable[tuple[str, int]], session: Session
|
||||
) -> dict[str, int | None]:
|
||||
"""Resolve shared_attrs to attributes_ids.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
results: dict[str, int | None] = {}
|
||||
missing_hashes: set[int] = set()
|
||||
for shared_attrs, data_hash in shared_attrs_data_hashes:
|
||||
if (attributes_id := self._id_map.get(shared_attrs)) is None:
|
||||
missing_hashes.add(data_hash)
|
||||
|
||||
results[shared_attrs] = attributes_id
|
||||
|
||||
if not missing_hashes:
|
||||
return results
|
||||
|
||||
return results | self._load_from_hashes(missing_hashes, session)
|
||||
|
||||
def _load_from_hashes(
|
||||
self, hashes: Iterable[int], session: Session
|
||||
) -> dict[str, int | None]:
|
||||
"""Load the shared_attrs to attributes_ids mapping into memory from a list of hashes.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
results: dict[str, int | None] = {}
|
||||
with session.no_autoflush:
|
||||
for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
|
||||
for attributes_id, shared_attrs in session.execute(
|
||||
get_shared_attributes(hashs_chunk)
|
||||
):
|
||||
results[shared_attrs] = self._id_map[shared_attrs] = cast(
|
||||
int, attributes_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def add_pending(self, db_state_attributes: StateAttributes) -> None:
|
||||
"""Add a pending StateAttributes that will be committed at the next interval.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
assert db_state_attributes.shared_attrs is not None
|
||||
shared_attrs: str = db_state_attributes.shared_attrs
|
||||
self._pending[shared_attrs] = db_state_attributes
|
||||
|
||||
def post_commit_pending(self) -> None:
|
||||
"""Call after commit to load the attributes_ids of the new StateAttributes into the LRU.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
for shared_attrs, db_state_attributes in self._pending.items():
|
||||
self._id_map[shared_attrs] = db_state_attributes.attributes_id
|
||||
self._pending.clear()
|
||||
|
||||
def evict_purged(self, attributes_ids: set[int]) -> None:
|
||||
"""Evict purged attributes_ids from the cache when they are no longer used.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
id_map = self._id_map
|
||||
state_attributes_ids_reversed = {
|
||||
attributes_id: shared_attrs
|
||||
for shared_attrs, attributes_id in id_map.items()
|
||||
}
|
||||
# Evict any purged data from the cache
|
||||
for purged_attributes_id in attributes_ids.intersection(
|
||||
state_attributes_ids_reversed
|
||||
):
|
||||
id_map.pop(state_attributes_ids_reversed[purged_attributes_id], None)
|
|
@ -4,12 +4,11 @@ from __future__ import annotations
|
|||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from lru import LRU # pylint: disable=no-name-in-module
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.core import Event
|
||||
|
||||
from . import BaseTableManager
|
||||
from . import BaseLRUTableManager
|
||||
from ..const import SQLITE_MAX_BIND_VARS
|
||||
from ..db_schema import StatesMeta
|
||||
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
|
||||
|
@ -21,15 +20,13 @@ if TYPE_CHECKING:
|
|||
CACHE_SIZE = 8192
|
||||
|
||||
|
||||
class StatesMetaManager(BaseTableManager):
|
||||
class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
|
||||
"""Manage the StatesMeta table."""
|
||||
|
||||
def __init__(self, recorder: Recorder) -> None:
|
||||
"""Initialize the states meta manager."""
|
||||
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
|
||||
self._pending: dict[str, StatesMeta] = {}
|
||||
self._did_first_load = False
|
||||
super().__init__(recorder)
|
||||
super().__init__(recorder, CACHE_SIZE)
|
||||
|
||||
def load(self, events: list[Event], session: Session) -> None:
|
||||
"""Load the entity_id to metadata_id mapping into memory.
|
||||
|
@ -112,14 +109,6 @@ class StatesMetaManager(BaseTableManager):
|
|||
|
||||
return results
|
||||
|
||||
def get_pending(self, entity_id: str) -> StatesMeta | None:
|
||||
"""Get pending StatesMeta that have not be assigned ids yet.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
return self._pending.get(entity_id)
|
||||
|
||||
def add_pending(self, db_states_meta: StatesMeta) -> None:
|
||||
"""Add a pending StatesMeta that will be committed at the next interval.
|
||||
|
||||
|
@ -140,15 +129,6 @@ class StatesMetaManager(BaseTableManager):
|
|||
self._id_map[entity_id] = db_states_meta.metadata_id
|
||||
self._pending.clear()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the states meta manager after the database has been reset or changed.
|
||||
|
||||
This call is not thread-safe and must be called from the
|
||||
recorder thread.
|
||||
"""
|
||||
self._id_map.clear()
|
||||
self._pending.clear()
|
||||
|
||||
def evict_purged(self, entity_ids: Iterable[str]) -> None:
|
||||
"""Evict purged entity_ids from the cache when they are no longer used.
|
||||
|
||||
|
|
|
@ -1872,9 +1872,11 @@ def test_deduplication_event_data_inside_commit_interval(
|
|||
assert all(event.data_id == first_data_id for event in events)
|
||||
|
||||
|
||||
# Patch STATE_ATTRIBUTES_ID_CACHE_SIZE since otherwise
|
||||
# Patch CACHE_SIZE since otherwise
|
||||
# the CI can fail because the test takes too long to run
|
||||
@patch("homeassistant.components.recorder.core.STATE_ATTRIBUTES_ID_CACHE_SIZE", 5)
|
||||
@patch(
|
||||
"homeassistant.components.recorder.table_managers.state_attributes.CACHE_SIZE", 5
|
||||
)
|
||||
def test_deduplication_state_attributes_inside_commit_interval(
|
||||
hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
|
@ -2159,4 +2161,8 @@ async def test_lru_increases_with_many_entities(
|
|||
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10))
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
assert recorder_mock._state_attributes_ids.get_size() == mock_entity_count * 2
|
||||
assert (
|
||||
recorder_mock.state_attributes_manager._id_map.get_size()
|
||||
== mock_entity_count * 2
|
||||
)
|
||||
assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2
|
||||
|
|
Loading…
Add table
Reference in a new issue