Reduce latency to find stats metadata (#89824)

J. Nick Koston 2023-03-16 19:00:02 -10:00 committed by GitHub
parent 04a99fdbfc
commit f6f3565796
13 changed files with 589 additions and 255 deletions
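At its core, the change replaces ad-hoc statistics_meta queries with a StatisticsMetaManager that keeps an in-memory LRU cache keyed by statistic_id and only queries the database for cache misses. Below is a minimal sketch of that cache-aside lookup pattern, not part of the diff itself; the MetaCache name and the fetch_from_db callable are illustrative stand-ins, while LRU and the 8192 cache size are taken from the new table manager further down.

from lru import LRU  # pylint: disable=no-name-in-module

CACHE_SIZE = 8192


class MetaCache:
    """Illustrative cache-aside lookup, the pattern used by StatisticsMetaManager."""

    def __init__(self, fetch_from_db):
        # fetch_from_db: callable taking a list of statistic_ids and returning
        # {statistic_id: (metadata_id, metadata)} straight from the database.
        self._fetch_from_db = fetch_from_db
        self._cache = LRU(CACHE_SIZE)

    def get_many(self, statistic_ids):
        results, missing = {}, []
        for statistic_id in statistic_ids:
            if (hit := self._cache.get(statistic_id)) is not None:
                results[statistic_id] = hit
            else:
                missing.append(statistic_id)
        if not missing:
            # Every id was served from memory; no database round trip at all.
            return results
        fetched = self._fetch_from_db(missing)  # one query for all misses
        for statistic_id, id_meta in fetched.items():
            self._cache[statistic_id] = id_meta
        return results | fetched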

View file

@ -86,6 +86,7 @@ from .table_managers.event_types import EventTypeManager
from .table_managers.state_attributes import StateAttributesManager
from .table_managers.states import StatesManager
from .table_managers.states_meta import StatesMetaManager
from .table_managers.statistics_meta import StatisticsMetaManager
from .tasks import (
AdjustLRUSizeTask,
AdjustStatisticsTask,
@ -172,6 +173,7 @@ class Recorder(threading.Thread):
threading.Thread.__init__(self, name="Recorder")
self.hass = hass
self.thread_id: int | None = None
self.auto_purge = auto_purge
self.auto_repack = auto_repack
self.keep_days = keep_days
@ -208,6 +210,7 @@ class Recorder(threading.Thread):
self.state_attributes_manager = StateAttributesManager(
self, exclude_attributes_by_domain
)
self.statistics_meta_manager = StatisticsMetaManager(self)
self.event_session: Session | None = None
self._get_session: Callable[[], Session] | None = None
self._completed_first_database_setup: bool | None = None
@ -613,6 +616,7 @@ class Recorder(threading.Thread):
def run(self) -> None:
"""Start processing events to save."""
self.thread_id = threading.get_ident()
setup_result = self._setup_recorder()
if not setup_result:
@ -668,7 +672,7 @@ class Recorder(threading.Thread):
"Database Migration Failed",
"recorder_database_migration",
)
self._activate_and_set_db_ready()
self.hass.add_job(self.async_set_db_ready)
self._shutdown()
return
@ -687,7 +691,14 @@ class Recorder(threading.Thread):
def _activate_and_set_db_ready(self) -> None:
"""Activate the table managers or schedule migrations and mark the db as ready."""
with session_scope(session=self.get_session()) as session:
with session_scope(session=self.get_session(), read_only=True) as session:
# Prime the statistics meta manager as soon as possible
# so that frontend queries avoid a thundering herd of
# statistics metadata lookups when a lot of statistics
# graphs are shown on the frontend.
if self.schema_version >= 23:
self.statistics_meta_manager.load(session)
if (
self.schema_version < 36
or session.execute(has_events_context_ids_to_migrate()).scalar()
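The priming above is the latency win in the commit title: once load() has filled the cache inside the recorder thread, the burst of frontend statistics requests is served from memory instead of each issuing its own statistics_meta SELECT. A hedged sketch of the consumer side, assuming an already set up recorder instance and an open session; the handle_graph_request name is illustrative.

def handle_graph_request(instance, session, statistic_ids):
    # After the startup load(), get_many() finds every requested id in the
    # LRU cache and only falls back to the database for genuinely new ids,
    # so many concurrent graph requests no longer fan out into one SELECT each.
    return instance.statistics_meta_manager.get_many(
        session, statistic_ids=statistic_ids
    )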
@ -758,10 +769,11 @@ class Recorder(threading.Thread):
non_state_change_events.append(event_)
assert self.event_session is not None
self.event_data_manager.load(non_state_change_events, self.event_session)
self.event_type_manager.load(non_state_change_events, self.event_session)
self.states_meta_manager.load(state_change_events, self.event_session)
self.state_attributes_manager.load(state_change_events, self.event_session)
session = self.event_session
self.event_data_manager.load(non_state_change_events, session)
self.event_type_manager.load(non_state_change_events, session)
self.states_meta_manager.load(state_change_events, session)
self.state_attributes_manager.load(state_change_events, session)
def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None:
"""Process a task, guarding against exceptions to ensure the loop does not collapse."""
@ -1077,6 +1089,7 @@ class Recorder(threading.Thread):
self.event_data_manager.reset()
self.event_type_manager.reset()
self.states_meta_manager.reset()
self.statistics_meta_manager.reset()
if not self.event_session:
return

View file

@ -873,7 +873,7 @@ def _apply_update( # noqa: C901
# There may be duplicated statistics_meta entries, delete duplicates
# and try again
with session_scope(session=session_maker()) as session:
delete_statistics_meta_duplicates(session)
delete_statistics_meta_duplicates(instance, session)
_create_index(
session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
)

View file

@ -21,7 +21,7 @@ from sqlalchemy.engine import Engine
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal_column, true
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.lambdas import StatementLambdaElement
import voluptuous as vol
@ -132,16 +132,6 @@ QUERY_STATISTICS_SUMMARY_SUM = (
.label("rownum"),
)
QUERY_STATISTIC_META = (
StatisticsMeta.id,
StatisticsMeta.statistic_id,
StatisticsMeta.source,
StatisticsMeta.unit_of_measurement,
StatisticsMeta.has_mean,
StatisticsMeta.has_sum,
StatisticsMeta.name,
)
STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
**{unit: DataRateConverter for unit in DataRateConverter.VALID_UNITS},
@ -373,56 +363,6 @@ def get_start_time() -> datetime:
return last_period
def _update_or_add_metadata(
session: Session,
new_metadata: StatisticMetaData,
old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
) -> int:
"""Get metadata_id for a statistic_id.
If the statistic_id is previously unknown, add it. If it's already known, update
metadata if needed.
Updating metadata source is not possible.
"""
statistic_id = new_metadata["statistic_id"]
if statistic_id not in old_metadata_dict:
meta = StatisticsMeta.from_meta(new_metadata)
session.add(meta)
session.flush() # Flush to get the metadata id assigned
_LOGGER.debug(
"Added new statistics metadata for %s, new_metadata: %s",
statistic_id,
new_metadata,
)
return meta.id
metadata_id, old_metadata = old_metadata_dict[statistic_id]
if (
old_metadata["has_mean"] != new_metadata["has_mean"]
or old_metadata["has_sum"] != new_metadata["has_sum"]
or old_metadata["name"] != new_metadata["name"]
or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"]
):
session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
{
StatisticsMeta.has_mean: new_metadata["has_mean"],
StatisticsMeta.has_sum: new_metadata["has_sum"],
StatisticsMeta.name: new_metadata["name"],
StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"],
},
synchronize_session=False,
)
_LOGGER.debug(
"Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s",
statistic_id,
old_metadata,
new_metadata,
)
return metadata_id
def _find_duplicates(
session: Session, table: type[StatisticsBase]
) -> tuple[list[int], list[dict]]:
@ -642,13 +582,16 @@ def _delete_statistics_meta_duplicates(session: Session) -> int:
return total_deleted_rows
def delete_statistics_meta_duplicates(session: Session) -> None:
def delete_statistics_meta_duplicates(instance: Recorder, session: Session) -> None:
"""Identify and delete duplicated statistics_meta.
This is used when migrating from schema version 28 to schema version 29.
"""
deleted_statistics_rows = _delete_statistics_meta_duplicates(session)
if deleted_statistics_rows:
statistics_meta_manager = instance.statistics_meta_manager
statistics_meta_manager.reset()
statistics_meta_manager.load(session)
_LOGGER.info(
"Deleted %s duplicated statistics_meta rows", deleted_statistics_rows
)
@ -750,6 +693,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
"""
start = dt_util.as_utc(start)
end = start + timedelta(minutes=5)
statistics_meta_manager = instance.statistics_meta_manager
# Return if we already have 5-minute statistics for the requested period
with session_scope(
@ -782,7 +726,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
# Insert collected statistics in the database
for stats in platform_stats:
metadata_id = _update_or_add_metadata(
_, metadata_id = statistics_meta_manager.update_or_add(
session, stats["meta"], current_metadata
)
_insert_statistics(
@ -877,28 +821,8 @@ def _update_statistics(
)
def _generate_get_metadata_stmt(
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> StatementLambdaElement:
"""Generate a statement to fetch metadata."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
if statistic_ids:
stmt += lambda q: q.where(
# https://github.com/python/mypy/issues/2608
StatisticsMeta.statistic_id.in_(statistic_ids) # type:ignore[arg-type]
)
if statistic_source is not None:
stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
if statistic_type == "mean":
stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
elif statistic_type == "sum":
stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
return stmt
def get_metadata_with_session(
instance: Recorder,
session: Session,
*,
statistic_ids: list[str] | None = None,
@ -908,31 +832,15 @@ def get_metadata_with_session(
"""Fetch meta data.
Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id.
If statistic_ids is given, fetch metadata only for the listed statistics_ids.
If statistic_type is given, fetch metadata only for statistic_ids supporting it.
"""
# Fetch metadata from the database
stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source)
result = execute_stmt_lambda_element(session, stmt)
if not result:
return {}
return {
meta.statistic_id: (
meta.id,
{
"has_mean": meta.has_mean,
"has_sum": meta.has_sum,
"name": meta.name,
"source": meta.source,
"statistic_id": meta.statistic_id,
"unit_of_measurement": meta.unit_of_measurement,
},
)
for meta in result
}
return instance.statistics_meta_manager.get_many(
session,
statistic_ids=statistic_ids,
statistic_type=statistic_type,
statistic_source=statistic_source,
)
def get_metadata(
@ -945,6 +853,7 @@ def get_metadata(
"""Return metadata for statistic_ids."""
with session_scope(hass=hass, read_only=True) as session:
return get_metadata_with_session(
get_instance(hass),
session,
statistic_ids=statistic_ids,
statistic_type=statistic_type,
@ -952,17 +861,10 @@ def get_metadata(
)
def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False)
def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
with session_scope(session=instance.get_session()) as session:
_clear_statistics_with_session(session, statistic_ids)
instance.statistics_meta_manager.delete(session, statistic_ids)
def update_statistics_metadata(
@ -972,20 +874,20 @@ def update_statistics_metadata(
new_unit_of_measurement: str | None | UndefinedType,
) -> None:
"""Update statistics metadata for a statistic_id."""
statistics_meta_manager = instance.statistics_meta_manager
if new_unit_of_measurement is not UNDEFINED:
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement})
if new_statistic_id is not UNDEFINED:
statistics_meta_manager.update_unit_of_measurement(
session, statistic_id, new_unit_of_measurement
)
if new_statistic_id is not UNDEFINED and new_statistic_id is not None:
with session_scope(
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
session.query(StatisticsMeta).filter(
(StatisticsMeta.statistic_id == statistic_id)
& (StatisticsMeta.source == DOMAIN)
).update({StatisticsMeta.statistic_id: new_statistic_id})
statistics_meta_manager.update_statistic_id(
session, DOMAIN, statistic_id, new_statistic_id
)
def list_statistic_ids(
@ -1004,7 +906,7 @@ def list_statistic_ids(
# Query the database
with session_scope(hass=hass, read_only=True) as session:
metadata = get_metadata_with_session(
metadata = get_instance(hass).statistics_meta_manager.get_many(
session, statistic_type=statistic_type, statistic_ids=statistic_ids
)
@ -1609,11 +1511,13 @@ def statistic_during_period(
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_id
if not (
metadata := get_metadata_with_session(session, statistic_ids=[statistic_id])
metadata := get_instance(hass).statistics_meta_manager.get(
session, statistic_id
)
):
return result
metadata_id = metadata[statistic_id][0]
metadata_id = metadata[0]
oldest_stat = _first_statistic(session, Statistics, metadata_id)
oldest_5_min_stat = None
@ -1724,7 +1628,7 @@ def statistic_during_period(
else:
result["change"] = None
state_unit = unit = metadata[statistic_id][1]["unit_of_measurement"]
state_unit = unit = metadata[1]["unit_of_measurement"]
if state := hass.states.get(statistic_id):
state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
convert = _get_statistic_to_display_unit_converter(unit, state_unit, units)
@ -1749,7 +1653,9 @@ def _statistics_during_period_with_session(
"""
metadata = None
# Fetch metadata for the given (or all) statistic_ids
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
metadata = get_instance(hass).statistics_meta_manager.get_many(
session, statistic_ids=statistic_ids
)
if not metadata:
return {}
@ -1885,7 +1791,9 @@ def _get_last_statistics(
statistic_ids = [statistic_id]
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_id
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
metadata = get_instance(hass).statistics_meta_manager.get_many(
session, statistic_ids=statistic_ids
)
if not metadata:
return {}
metadata_id = metadata[statistic_id][0]
@ -1973,7 +1881,9 @@ def get_latest_short_term_statistics(
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_ids
if not metadata:
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
metadata = get_instance(hass).statistics_meta_manager.get_many(
session, statistic_ids=statistic_ids
)
if not metadata:
return {}
metadata_ids = [
@ -2318,16 +2228,20 @@ def _filter_unique_constraint_integrity_error(
def _import_statistics_with_session(
instance: Recorder,
session: Session,
metadata: StatisticMetaData,
statistics: Iterable[StatisticData],
table: type[StatisticsBase],
) -> bool:
"""Import statistics to the database."""
old_metadata_dict = get_metadata_with_session(
statistics_meta_manager = instance.statistics_meta_manager
old_metadata_dict = statistics_meta_manager.get_many(
session, statistic_ids=[metadata["statistic_id"]]
)
metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
_, metadata_id = statistics_meta_manager.update_or_add(
session, metadata, old_metadata_dict
)
for stat in statistics:
if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]):
_update_statistics(session, table, stat_id, stat)
@ -2350,7 +2264,9 @@ def import_statistics(
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
return _import_statistics_with_session(session, metadata, statistics, table)
return _import_statistics_with_session(
instance, session, metadata, statistics, table
)
@retryable_database_job("adjust_statistics")
@ -2364,7 +2280,9 @@ def adjust_statistics(
"""Process an add_statistics job."""
with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session(session, statistic_ids=[statistic_id])
metadata = instance.statistics_meta_manager.get_many(
session, statistic_ids=[statistic_id]
)
if statistic_id not in metadata:
return True
@ -2423,10 +2341,9 @@ def change_statistics_unit(
old_unit: str,
) -> None:
"""Change statistics unit for a statistic_id."""
statistics_meta_manager = instance.statistics_meta_manager
with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]).get(
statistic_id
)
metadata = statistics_meta_manager.get(session, statistic_id)
# Guard against the statistics being removed or updated before the
# change_statistics_unit job executes
@ -2447,9 +2364,10 @@ def change_statistics_unit(
)
for table in tables:
_change_statistics_unit_for_table(session, table, metadata_id, convert)
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: new_unit})
statistics_meta_manager.update_unit_of_measurement(
session, statistic_id, new_unit
)
@callback
@ -2495,16 +2413,19 @@ def _validate_db_schema_utf8(
"statistic_id": statistic_id,
"unit_of_measurement": None,
}
statistics_meta_manager = instance.statistics_meta_manager
# Try inserting some metadata which needs utf8mb4 support
try:
with session_scope(session=session_maker()) as session:
old_metadata_dict = get_metadata_with_session(
old_metadata_dict = statistics_meta_manager.get_many(
session, statistic_ids=[statistic_id]
)
try:
_update_or_add_metadata(session, metadata, old_metadata_dict)
_clear_statistics_with_session(session, statistic_ids=[statistic_id])
statistics_meta_manager.update_or_add(
session, metadata, old_metadata_dict
)
statistics_meta_manager.delete(session, statistic_ids=[statistic_id])
except OperationalError as err:
if err.orig and err.orig.args[0] == 1366:
_LOGGER.debug(
@ -2524,6 +2445,7 @@ def _validate_db_schema(
) -> set[str]:
"""Do some basic checks for common schema errors caused by manual migration."""
schema_errors: set[str] = set()
statistics_meta_manager = instance.statistics_meta_manager
# Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL
if instance.dialect_name not in (
@ -2586,7 +2508,9 @@ def _validate_db_schema(
try:
with session_scope(session=session_maker()) as session:
for table in tables:
_import_statistics_with_session(session, metadata, (statistics,), table)
_import_statistics_with_session(
instance, session, metadata, (statistics,), table
)
stored_statistics = _statistics_during_period_with_session(
hass,
session,
@ -2625,7 +2549,7 @@ def _validate_db_schema(
table.__tablename__,
"µs precision",
)
_clear_statistics_with_session(session, statistic_ids=[statistic_id])
statistics_meta_manager.delete(session, statistic_ids=[statistic_id])
except Exception as exc: # pylint: disable=broad-except
_LOGGER.exception("Error when validating DB schema: %s", exc)

View file

@ -14,7 +14,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventData
from ..queries import get_shared_event_datas
from ..util import chunked
from ..util import chunked, execute_stmt_lambda_element
if TYPE_CHECKING:
from ..core import Recorder
@ -96,8 +96,8 @@ class EventDataManager(BaseLRUTableManager[EventData]):
results: dict[str, int | None] = {}
with session.no_autoflush:
for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
for data_id, shared_data in session.execute(
get_shared_event_datas(hashs_chunk)
for data_id, shared_data in execute_stmt_lambda_element(
session, get_shared_event_datas(hashs_chunk)
):
results[shared_data] = self._id_map[shared_data] = cast(
int, data_id

View file

@ -12,7 +12,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventTypes
from ..queries import find_event_type_ids
from ..util import chunked
from ..util import chunked, execute_stmt_lambda_element
if TYPE_CHECKING:
from ..core import Recorder
@ -68,8 +68,8 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
with session.no_autoflush:
for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
for event_type_id, event_type in session.execute(
find_event_type_ids(missing_chunk)
for event_type_id, event_type in execute_stmt_lambda_element(
session, find_event_type_ids(missing_chunk)
):
results[event_type] = self._id_map[event_type] = cast(
int, event_type_id

View file

@ -15,7 +15,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StateAttributes
from ..queries import get_shared_attributes
from ..util import chunked
from ..util import chunked, execute_stmt_lambda_element
if TYPE_CHECKING:
from ..core import Recorder
@ -113,8 +113,8 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
results: dict[str, int | None] = {}
with session.no_autoflush:
for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
for attributes_id, shared_attrs in session.execute(
get_shared_attributes(hashs_chunk)
for attributes_id, shared_attrs in execute_stmt_lambda_element(
session, get_shared_attributes(hashs_chunk)
):
results[shared_attrs] = self._id_map[shared_attrs] = cast(
int, attributes_id

View file

@ -12,7 +12,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StatesMeta
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
from ..util import chunked
from ..util import chunked, execute_stmt_lambda_element
if TYPE_CHECKING:
from ..core import Recorder
@ -98,8 +98,8 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
with session.no_autoflush:
for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
for metadata_id, entity_id in session.execute(
find_states_metadata_ids(missing_chunk)
for metadata_id, entity_id in execute_stmt_lambda_element(
session, find_states_metadata_ids(missing_chunk)
):
metadata_id = cast(int, metadata_id)
results[entity_id] = metadata_id

View file

@ -0,0 +1,322 @@
"""Support managing StatesMeta."""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Literal, cast
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy import lambda_stmt, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import true
from sqlalchemy.sql.lambdas import StatementLambdaElement
from ..db_schema import StatisticsMeta
from ..models import StatisticMetaData
from ..util import execute_stmt_lambda_element
if TYPE_CHECKING:
from ..core import Recorder
CACHE_SIZE = 8192
_LOGGER = logging.getLogger(__name__)
QUERY_STATISTIC_META = (
StatisticsMeta.id,
StatisticsMeta.statistic_id,
StatisticsMeta.source,
StatisticsMeta.unit_of_measurement,
StatisticsMeta.has_mean,
StatisticsMeta.has_sum,
StatisticsMeta.name,
)
def _generate_get_metadata_stmt(
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> StatementLambdaElement:
"""Generate a statement to fetch metadata."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
if statistic_ids:
stmt += lambda q: q.where(
# https://github.com/python/mypy/issues/2608
StatisticsMeta.statistic_id.in_(statistic_ids) # type:ignore[arg-type]
)
if statistic_source is not None:
stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
if statistic_type == "mean":
stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
elif statistic_type == "sum":
stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
return stmt
def _statistics_meta_to_id_statistics_metadata(
meta: StatisticsMeta,
) -> tuple[int, StatisticMetaData]:
"""Convert StatisticsMeta tuple of metadata_id and StatisticMetaData."""
return (
meta.id,
{
"has_mean": meta.has_mean, # type: ignore[typeddict-item]
"has_sum": meta.has_sum, # type: ignore[typeddict-item]
"name": meta.name,
"source": meta.source, # type: ignore[typeddict-item]
"statistic_id": meta.statistic_id, # type: ignore[typeddict-item]
"unit_of_measurement": meta.unit_of_measurement,
},
)
class StatisticsMetaManager:
"""Manage the StatisticsMeta table."""
def __init__(self, recorder: Recorder) -> None:
"""Initialize the statistics meta manager."""
self.recorder = recorder
self._stat_id_to_id_meta: dict[str, tuple[int, StatisticMetaData]] = LRU(
CACHE_SIZE
)
def _clear_cache(self, statistic_ids: list[str]) -> None:
"""Clear the cache."""
for statistic_id in statistic_ids:
self._stat_id_to_id_meta.pop(statistic_id, None)
def _get_from_database(
self,
session: Session,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
"""Fetch meta data and process it into results and/or cache."""
# Only update the cache if we are in the recorder thread and there are no
# new objects that are not yet committed to the database in the session.
update_cache = (
not session.new
and not session.dirty
and self.recorder.thread_id == threading.get_ident()
)
results: dict[str, tuple[int, StatisticMetaData]] = {}
with session.no_autoflush:
stat_id_to_id_meta = self._stat_id_to_id_meta
for row in execute_stmt_lambda_element(
session,
_generate_get_metadata_stmt(
statistic_ids, statistic_type, statistic_source
),
):
statistics_meta = cast(StatisticsMeta, row)
id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta)
statistic_id = cast(str, statistics_meta.statistic_id)
results[statistic_id] = id_meta
if update_cache:
stat_id_to_id_meta[statistic_id] = id_meta
return results
def _assert_in_recorder_thread(self) -> None:
"""Assert that we are in the recorder thread."""
if self.recorder.thread_id != threading.get_ident():
raise RuntimeError("Detected unsafe call not in recorder thread")
def _add_metadata(
self, session: Session, statistic_id: str, new_metadata: StatisticMetaData
) -> int:
"""Add metadata to the database.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._assert_in_recorder_thread()
meta = StatisticsMeta.from_meta(new_metadata)
session.add(meta)
# Flush to assign an ID
session.flush()
_LOGGER.debug(
"Added new statistics metadata for %s, new_metadata: %s",
statistic_id,
new_metadata,
)
return meta.id
def _update_metadata(
self,
session: Session,
statistic_id: str,
new_metadata: StatisticMetaData,
old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
) -> tuple[bool, int]:
"""Update metadata in the database.
This call is not thread-safe and must be called from the
recorder thread.
"""
metadata_id, old_metadata = old_metadata_dict[statistic_id]
if not (
old_metadata["has_mean"] != new_metadata["has_mean"]
or old_metadata["has_sum"] != new_metadata["has_sum"]
or old_metadata["name"] != new_metadata["name"]
or old_metadata["unit_of_measurement"]
!= new_metadata["unit_of_measurement"]
):
return False, metadata_id
self._assert_in_recorder_thread()
session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
{
StatisticsMeta.has_mean: new_metadata["has_mean"],
StatisticsMeta.has_sum: new_metadata["has_sum"],
StatisticsMeta.name: new_metadata["name"],
StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"],
},
synchronize_session=False,
)
self._clear_cache([statistic_id])
_LOGGER.debug(
"Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s",
statistic_id,
old_metadata,
new_metadata,
)
return True, metadata_id
def load(self, session: Session) -> None:
"""Load the statistic_id to metadata_id mapping into memory.
This call is not thread-safe and must be called from the
recorder thread.
"""
self.get_many(session)
def get(
self, session: Session, statistic_id: str
) -> tuple[int, StatisticMetaData] | None:
"""Resolve statistic_id to the metadata_id."""
return self.get_many(session, [statistic_id]).get(statistic_id)
def get_many(
self,
session: Session,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
"""Fetch meta data.
Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id.
If statistic_ids is given, fetch metadata only for the listed statistics_ids.
If statistic_type is given, fetch metadata only for statistic_ids supporting it.
"""
if statistic_ids is None:
# Fetch metadata from the database
return self._get_from_database(
session,
statistic_type=statistic_type,
statistic_source=statistic_source,
)
if statistic_type is not None or statistic_source is not None:
# This was originally implemented but we never used it
# so the code was ripped out to reduce the maintenance
# burden.
raise ValueError(
"Providing statistic_type and statistic_source is mutually exclusive of statistic_ids"
)
results: dict[str, tuple[int, StatisticMetaData]] = {}
missing_statistic_id: list[str] = []
for statistic_id in statistic_ids:
if id_meta := self._stat_id_to_id_meta.get(statistic_id):
results[statistic_id] = id_meta
else:
missing_statistic_id.append(statistic_id)
if not missing_statistic_id:
return results
# Fetch metadata from the database
return results | self._get_from_database(
session, statistic_ids=missing_statistic_id
)
def update_or_add(
self,
session: Session,
new_metadata: StatisticMetaData,
old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
) -> tuple[bool, int]:
"""Get metadata_id for a statistic_id.
If the statistic_id is previously unknown, add it. If it's already known, update
metadata if needed.
Updating metadata source is not possible.
Returns a tuple of (updated, metadata_id), where updated is True if a new row
was added or an existing row was modified.
This call is not thread-safe and must be called from the
recorder thread.
"""
statistic_id = new_metadata["statistic_id"]
if statistic_id not in old_metadata_dict:
return True, self._add_metadata(session, statistic_id, new_metadata)
return self._update_metadata(
session, statistic_id, new_metadata, old_metadata_dict
)
def update_unit_of_measurement(
self, session: Session, statistic_id: str, new_unit: str | None
) -> None:
"""Update the unit of measurement for a statistic_id.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._assert_in_recorder_thread()
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: new_unit})
self._clear_cache([statistic_id])
def update_statistic_id(
self,
session: Session,
source: str,
old_statistic_id: str,
new_statistic_id: str,
) -> None:
"""Update the statistic_id for a statistic_id.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._assert_in_recorder_thread()
session.query(StatisticsMeta).filter(
(StatisticsMeta.statistic_id == old_statistic_id)
& (StatisticsMeta.source == source)
).update({StatisticsMeta.statistic_id: new_statistic_id})
self._clear_cache([old_statistic_id, new_statistic_id])
def delete(self, session: Session, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._assert_in_recorder_thread()
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False)
self._clear_cache(statistic_ids)
def reset(self) -> None:
"""Reset the cache."""
self._stat_id_to_id_meta = {}
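Taken together, the manager is a small cache-aside layer over statistics_meta with explicit thread-safety rules. The sketch below is an editorial usage illustration rather than part of the new file; instance, session and new_metadata are assumed to come from the recorder as elsewhere in this diff, and the entity ids are placeholders.

def demo_statistics_meta_usage(instance, session, new_metadata):
    """Illustrative only: how callers elsewhere in this diff use the manager."""
    manager = instance.statistics_meta_manager

    # Single lookup: (metadata_id, StatisticMetaData) or None if the id is unknown.
    id_meta = manager.get(session, "sensor.outdoor_temperature")

    # Bulk lookup: cache hits are returned directly, misses are fetched in one query.
    metas = manager.get_many(session, statistic_ids=["sensor.energy", "sensor.gas"])

    # Combining explicit statistic_ids with statistic_type or statistic_source
    # raises ValueError (see get_many above and the new test further down).

    # Insert or update metadata; returns (updated, metadata_id) and keeps the
    # cache coherent by evicting the statistic_id when a row is modified.
    updated, metadata_id = manager.update_or_add(session, new_metadata, metas)

    # Mutating calls (update_or_add, update_*, delete) assert they run in the
    # recorder thread and raise RuntimeError otherwise.
    return id_meta, updated, metadata_id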

View file

@ -145,31 +145,36 @@ def _parse_float(state: str) -> float:
return fstate
def _float_or_none(state: str) -> float | None:
"""Return a float or None."""
try:
return _parse_float(state)
except (ValueError, TypeError):
return None
def _entity_history_to_float_and_state(
entity_history: Iterable[State],
) -> list[tuple[float, State]]:
"""Return a list of (float, state) tuples for the given entity."""
return [
(fstate, state)
for state in entity_history
if (fstate := _float_or_none(state.state)) is not None
]
def _normalize_states(
hass: HomeAssistant,
session: Session,
old_metadatas: dict[str, tuple[int, StatisticMetaData]],
entity_history: Iterable[State],
fstates: list[tuple[float, State]],
entity_id: str,
) -> tuple[str | None, list[tuple[float, State]]]:
"""Normalize units."""
old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
state_unit: str | None = None
fstates: list[tuple[float, State]] = []
for state in entity_history:
try:
fstate = _parse_float(state.state)
except (ValueError, TypeError): # TypeError to guard for NULL state in DB
continue
fstates.append((fstate, state))
if not fstates:
return None, fstates
state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
statistics_unit: str | None
state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
if not old_metadata:
# We've not seen this sensor before, the first valid state determines the unit
# used for statistics
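The two new helpers above, _float_or_none and _entity_history_to_float_and_state, pull the old inline parsing loop out of _normalize_states so that non-numeric or NULL states are dropped before any metadata lookup happens. A simplified illustration of that filtering follows; plain float() stands in for the stricter _parse_float and bare strings stand in for State objects.

def _float_or_none_demo(value):
    # Mirrors _float_or_none: swallow ValueError/TypeError and signal "skip" with None.
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


raw_states = ["21.5", "unknown", "unavailable", "22.0", None]
float_and_state = [
    (fstate, state)
    for state in raw_states
    if (fstate := _float_or_none_demo(state)) is not None
]
# -> [(21.5, "21.5"), (22.0, "22.0")]; "unknown", "unavailable" and None never
# reach the statistics math.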
@ -379,7 +384,15 @@ def compile_statistics(
Note: This will query the database and must not be run in the event loop
"""
with recorder_util.session_scope(hass=hass) as session:
# There is already an active session when this code is called, since
# it is called from the recorder's statistics compilation. We need to make
# sure this session never gets committed, since that would put it out of
# sync with the recorder's statistics session, so we mark it as read only.
#
# If we ever need to write to the database from this function we
# will need to refactor the recorder statistics to use a single
# session.
with recorder_util.session_scope(hass=hass, read_only=True) as session:
compiled = _compile_statistics(hass, session, start, end)
return compiled
@ -395,10 +408,6 @@ def _compile_statistics( # noqa: C901
sensor_states = _get_sensor_states(hass)
wanted_statistics = _wanted_statistics(sensor_states)
old_metadatas = statistics.get_metadata_with_session(
session, statistic_ids=[i.entity_id for i in sensor_states]
)
# Get history between start and end
entities_full_history = [
i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id]
@ -427,34 +436,41 @@ def _compile_statistics( # noqa: C901
entity_ids=entities_significant_history,
)
history_list = {**history_list, **_history_list}
# If there are no recent state changes, the sensor's state may already be pruned
# from the recorder. Get the state from the state machine instead.
for _state in sensor_states:
if _state.entity_id not in history_list:
history_list[_state.entity_id] = [_state]
to_process = []
to_query = []
entities_with_float_states: dict[str, list[tuple[float, State]]] = {}
for _state in sensor_states:
entity_id = _state.entity_id
if entity_id not in history_list:
# If there are no recent state changes, the sensor's state may already be pruned
# from the recorder. Get the state from the state machine instead.
if not (entity_history := history_list.get(entity_id, [_state])):
continue
if not (float_states := _entity_history_to_float_and_state(entity_history)):
continue
entities_with_float_states[entity_id] = float_states
entity_history = history_list[entity_id]
statistics_unit, fstates = _normalize_states(
# Only look up metadata for entities that have valid float states,
# since looking up the rest would only produce cache misses for
# statistic_ids that are not in the metadata table, and we are not
# working with them anyway.
old_metadatas = statistics.get_metadata_with_session(
get_instance(hass), session, statistic_ids=list(entities_with_float_states)
)
to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = []
to_query: list[str] = []
for _state in sensor_states:
entity_id = _state.entity_id
if not (maybe_float_states := entities_with_float_states.get(entity_id)):
continue
statistics_unit, valid_float_states = _normalize_states(
hass,
session,
old_metadatas,
entity_history,
maybe_float_states,
entity_id,
)
if not fstates:
if not valid_float_states:
continue
state_class = _state.attributes[ATTR_STATE_CLASS]
to_process.append((entity_id, statistics_unit, state_class, fstates))
state_class: str = _state.attributes[ATTR_STATE_CLASS]
to_process.append((entity_id, statistics_unit, state_class, valid_float_states))
if "sum" in wanted_statistics[entity_id]:
to_query.append(entity_id)
@ -465,7 +481,7 @@ def _compile_statistics( # noqa: C901
entity_id,
statistics_unit,
state_class,
fstates,
valid_float_states,
) in to_process:
# Check metadata
if old_metadata := old_metadatas.get(entity_id):
@ -507,20 +523,20 @@ def _compile_statistics( # noqa: C901
if "max" in wanted_statistics[entity_id]:
stat["max"] = max(
*itertools.islice(
zip(*fstates), # type: ignore[typeddict-item]
zip(*valid_float_states), # type: ignore[typeddict-item]
1,
)
)
if "min" in wanted_statistics[entity_id]:
stat["min"] = min(
*itertools.islice(
zip(*fstates), # type: ignore[typeddict-item]
zip(*valid_float_states), # type: ignore[typeddict-item]
1,
)
)
if "mean" in wanted_statistics[entity_id]:
stat["mean"] = _time_weighted_average(fstates, start, end)
stat["mean"] = _time_weighted_average(valid_float_states, start, end)
if "sum" in wanted_statistics[entity_id]:
last_reset = old_last_reset = None
@ -535,7 +551,7 @@ def _compile_statistics( # noqa: C901
new_state = old_state = last_stat["state"]
_sum = last_stat["sum"] or 0.0
for fstate, state in fstates:
for fstate, state in valid_float_states:
reset = False
if (
state_class != SensorStateClass.TOTAL_INCREASING

View file

@ -0,0 +1 @@
"""Tests for the recorder table managers."""

View file

@ -0,0 +1,53 @@
"""The tests for the Recorder component."""
from __future__ import annotations
import pytest
from homeassistant.components import recorder
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
from tests.typing import RecorderInstanceGenerator
async def test_passing_mutually_exclusive_options_to_get_many(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
"""Test passing mutually exclusive options to get_many."""
instance = await async_setup_recorder_instance(
hass, {recorder.CONF_COMMIT_INTERVAL: 0}
)
with session_scope(session=instance.get_session()) as session:
with pytest.raises(ValueError):
instance.statistics_meta_manager.get_many(
session,
statistic_ids=["light.kitchen"],
statistic_type="mean",
)
with pytest.raises(ValueError):
instance.statistics_meta_manager.get_many(
session, statistic_ids=["light.kitchen"], statistic_source="sensor"
)
assert (
instance.statistics_meta_manager.get_many(
session,
statistic_ids=["light.kitchen"],
)
== {}
)
async def test_unsafe_calls_to_statistics_meta_manager(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
"""Test we raise when trying to call non-threadsafe functions on statistics_meta_manager."""
instance = await async_setup_recorder_instance(
hass, {recorder.CONF_COMMIT_INTERVAL: 0}
)
with session_scope(session=instance.get_session()) as session, pytest.raises(
RuntimeError, match="Detected unsafe call not in recorder thread"
):
instance.statistics_meta_manager.delete(
session,
statistic_ids=["light.kitchen"],
)

View file

@ -564,32 +564,6 @@ def _add_entities(hass, entity_ids):
return states
def _add_events(hass, events):
with session_scope(hass=hass) as session:
session.query(Events).delete(synchronize_session=False)
for event_type in events:
hass.bus.fire(event_type)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
events = []
for event, event_data, event_types in (
session.query(Events, EventData, EventTypes)
.outerjoin(EventTypes, (Events.event_type_id == EventTypes.event_type_id))
.outerjoin(EventData, Events.data_id == EventData.data_id)
):
event = cast(Events, event)
event_data = cast(EventData, event_data)
event_types = cast(EventTypes, event_types)
native_event = event.to_native()
if event_data:
native_event.data = event_data.to_native()
native_event.event_type = event_types.event_type
events.append(native_event)
return events
def _state_with_context(hass, entity_id):
# We don't restore context unless we need it by joining the
# events table on the event_id for state_changed events
@ -646,25 +620,53 @@ def test_saving_state_incl_entities(
assert _state_with_context(hass, "test2.recorder").as_dict() == states[0].as_dict()
def test_saving_event_exclude_event_type(
hass_recorder: Callable[..., HomeAssistant]
async def test_saving_event_exclude_event_type(
async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant,
) -> None:
"""Test saving and restoring an event."""
hass = hass_recorder(
{
"exclude": {
"event_types": [
"service_registered",
"homeassistant_start",
"component_loaded",
"core_config_updated",
"homeassistant_started",
"test",
]
}
config = {
"exclude": {
"event_types": [
"service_registered",
"homeassistant_start",
"component_loaded",
"core_config_updated",
"homeassistant_started",
"test",
]
}
)
events = _add_events(hass, ["test", "test2"])
}
instance = await async_setup_recorder_instance(hass, config)
events = ["test", "test2"]
for event_type in events:
hass.bus.async_fire(event_type)
await async_wait_recording_done(hass)
def _get_events(hass: HomeAssistant, event_types: list[str]) -> list[Event]:
with session_scope(hass=hass) as session:
events = []
for event, event_data, event_types in (
session.query(Events, EventData, EventTypes)
.outerjoin(
EventTypes, (Events.event_type_id == EventTypes.event_type_id)
)
.outerjoin(EventData, Events.data_id == EventData.data_id)
.where(EventTypes.event_type.in_(event_types))
):
event = cast(Events, event)
event_data = cast(EventData, event_data)
event_types = cast(EventTypes, event_types)
native_event = event.to_native()
if event_data:
native_event.data = event_data.to_native()
native_event.event_type = event_types.event_type
events.append(native_event)
return events
events = await instance.async_add_executor_job(_get_events, hass, ["test", "test2"])
assert len(events) == 1
assert events[0].event_type == "test2"

View file

@ -22,12 +22,10 @@ from homeassistant.components.recorder.models import (
)
from homeassistant.components.recorder.statistics import (
STATISTIC_UNIT_TO_UNIT_CONVERTER,
_generate_get_metadata_stmt,
_generate_max_mean_min_statistic_in_sub_period_stmt,
_generate_statistics_at_time_stmt,
_generate_statistics_during_period_stmt,
_statistics_during_period_with_session,
_update_or_add_metadata,
async_add_external_statistics,
async_import_statistics,
delete_statistics_duplicates,
@ -38,6 +36,10 @@ from homeassistant.components.recorder.statistics import (
get_metadata,
list_statistic_ids,
)
from homeassistant.components.recorder.table_managers.statistics_meta import (
StatisticsMetaManager,
_generate_get_metadata_stmt,
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.sensor import UNIT_CONVERTERS
from homeassistant.const import UnitOfTemperature
@ -1520,7 +1522,8 @@ def test_delete_metadata_duplicates_no_duplicates(
hass = hass_recorder()
wait_recording_done(hass)
with session_scope(hass=hass) as session:
delete_statistics_meta_duplicates(session)
instance = recorder.get_instance(hass)
delete_statistics_meta_duplicates(instance, session)
assert "duplicated statistics_meta rows" not in caplog.text
@ -1562,9 +1565,9 @@ async def test_validate_db_schema_fix_utf8_issue(
with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"
), patch(
"homeassistant.components.recorder.statistics._update_or_add_metadata",
"homeassistant.components.recorder.table_managers.statistics_meta.StatisticsMetaManager.update_or_add",
wraps=StatisticsMetaManager.update_or_add,
side_effect=[utf8_error, DEFAULT, DEFAULT],
wraps=_update_or_add_metadata,
):
await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)