Complete strict typing for recorder (#71274)
* Complete strict typing for recorder * update tests * Update tests/components/recorder/test_migrate.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Update tests/components/recorder/test_migrate.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Remove the asserts * remove ignore comments Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
13ce0a7d6a
commit
eb77f8db85
10 changed files with 166 additions and 309 deletions
|
@ -177,23 +177,7 @@ homeassistant.components.pure_energie.*
|
|||
homeassistant.components.rainmachine.*
|
||||
homeassistant.components.rdw.*
|
||||
homeassistant.components.recollect_waste.*
|
||||
homeassistant.components.recorder
|
||||
homeassistant.components.recorder.const
|
||||
homeassistant.components.recorder.core
|
||||
homeassistant.components.recorder.backup
|
||||
homeassistant.components.recorder.executor
|
||||
homeassistant.components.recorder.history
|
||||
homeassistant.components.recorder.models
|
||||
homeassistant.components.recorder.pool
|
||||
homeassistant.components.recorder.purge
|
||||
homeassistant.components.recorder.repack
|
||||
homeassistant.components.recorder.run_history
|
||||
homeassistant.components.recorder.services
|
||||
homeassistant.components.recorder.statistics
|
||||
homeassistant.components.recorder.system_health
|
||||
homeassistant.components.recorder.tasks
|
||||
homeassistant.components.recorder.util
|
||||
homeassistant.components.recorder.websocket_api
|
||||
homeassistant.components.recorder.*
|
||||
homeassistant.components.remote.*
|
||||
homeassistant.components.renault.*
|
||||
homeassistant.components.ridwell.*
|
||||
|
|
|
@ -171,7 +171,7 @@ class Recorder(threading.Thread):
|
|||
self._pending_event_data: dict[str, EventData] = {}
|
||||
self._pending_expunge: list[States] = []
|
||||
self.event_session: Session | None = None
|
||||
self.get_session: Callable[[], Session] | None = None
|
||||
self._get_session: Callable[[], Session] | None = None
|
||||
self._completed_first_database_setup: bool | None = None
|
||||
self.async_migration_event = asyncio.Event()
|
||||
self.migration_in_progress = False
|
||||
|
@ -205,6 +205,12 @@ class Recorder(threading.Thread):
|
|||
"""Return if the recorder is recording."""
|
||||
return self._event_listener is not None
|
||||
|
||||
def get_session(self) -> Session:
|
||||
"""Get a new sqlalchemy session."""
|
||||
if self._get_session is None:
|
||||
raise RuntimeError("The database connection has not been established")
|
||||
return self._get_session()
|
||||
|
||||
def queue_task(self, task: RecorderTask) -> None:
|
||||
"""Add a task to the recorder queue."""
|
||||
self._queue.put(task)
|
||||
|
@ -459,7 +465,7 @@ class Recorder(threading.Thread):
|
|||
@callback
|
||||
def _async_setup_periodic_tasks(self) -> None:
|
||||
"""Prepare periodic tasks."""
|
||||
if self.hass.is_stopping or not self.get_session:
|
||||
if self.hass.is_stopping or not self._get_session:
|
||||
# Home Assistant is shutting down
|
||||
return
|
||||
|
||||
|
@ -591,7 +597,7 @@ class Recorder(threading.Thread):
|
|||
while tries <= self.db_max_retries:
|
||||
try:
|
||||
self._setup_connection()
|
||||
return migration.get_schema_version(self)
|
||||
return migration.get_schema_version(self.get_session)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error during connection setup: %s (retrying in %s seconds)",
|
||||
|
@ -619,7 +625,9 @@ class Recorder(threading.Thread):
|
|||
self.hass.add_job(self._async_migration_started)
|
||||
|
||||
try:
|
||||
migration.migrate_schema(self, current_version)
|
||||
migration.migrate_schema(
|
||||
self.hass, self.engine, self.get_session, current_version
|
||||
)
|
||||
except exc.DatabaseError as err:
|
||||
if self._handle_database_error(err):
|
||||
return True
|
||||
|
@ -896,7 +904,6 @@ class Recorder(threading.Thread):
|
|||
|
||||
def _open_event_session(self) -> None:
|
||||
"""Open the event session."""
|
||||
assert self.get_session is not None
|
||||
self.event_session = self.get_session()
|
||||
self.event_session.expire_on_commit = False
|
||||
|
||||
|
@ -1011,7 +1018,7 @@ class Recorder(threading.Thread):
|
|||
sqlalchemy_event.listen(self.engine, "connect", setup_recorder_connection)
|
||||
|
||||
Base.metadata.create_all(self.engine)
|
||||
self.get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
|
||||
self._get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
|
||||
_LOGGER.debug("Connected to recorder database")
|
||||
|
||||
def _close_connection(self) -> None:
|
||||
|
@ -1019,11 +1026,10 @@ class Recorder(threading.Thread):
|
|||
assert self.engine is not None
|
||||
self.engine.dispose()
|
||||
self.engine = None
|
||||
self.get_session = None
|
||||
self._get_session = None
|
||||
|
||||
def _setup_run(self) -> None:
|
||||
"""Log the start of the current run and schedule any needed jobs."""
|
||||
assert self.get_session is not None
|
||||
with session_scope(session=self.get_session()) as session:
|
||||
end_incomplete_runs(session, self.run_history.recording_start)
|
||||
self.run_history.start(session)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
"""Schema migration helpers."""
|
||||
from collections.abc import Callable, Iterable
|
||||
import contextlib
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import (
|
||||
DatabaseError,
|
||||
InternalError,
|
||||
|
@ -13,9 +15,12 @@ from sqlalchemy.exc import (
|
|||
ProgrammingError,
|
||||
SQLAlchemyError,
|
||||
)
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.schema import AddConstraint, DropConstraint
|
||||
from sqlalchemy.sql.expression import true
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .models import (
|
||||
SCHEMA_VERSION,
|
||||
TABLE_STATES,
|
||||
|
@ -33,7 +38,7 @@ from .util import session_scope
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def raise_if_exception_missing_str(ex, match_substrs):
|
||||
def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) -> None:
|
||||
"""Raise an exception if the exception and cause do not contain the match substrs."""
|
||||
lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()]
|
||||
for str_sub in match_substrs:
|
||||
|
@ -44,10 +49,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
|
|||
raise ex
|
||||
|
||||
|
||||
def get_schema_version(instance: Any) -> int:
|
||||
def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
||||
"""Get the schema version."""
|
||||
assert instance.get_session is not None
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
res = (
|
||||
session.query(SchemaChanges)
|
||||
.order_by(SchemaChanges.change_id.desc())
|
||||
|
@ -61,7 +65,7 @@ def get_schema_version(instance: Any) -> int:
|
|||
"No schema version found. Inspected version: %s", current_version
|
||||
)
|
||||
|
||||
return current_version
|
||||
return cast(int, current_version)
|
||||
|
||||
|
||||
def schema_is_current(current_version: int) -> bool:
|
||||
|
@ -69,21 +73,27 @@ def schema_is_current(current_version: int) -> bool:
|
|||
return current_version == SCHEMA_VERSION
|
||||
|
||||
|
||||
def migrate_schema(instance: Any, current_version: int) -> None:
|
||||
def migrate_schema(
|
||||
hass: HomeAssistant,
|
||||
engine: Engine,
|
||||
session_maker: Callable[[], Session],
|
||||
current_version: int,
|
||||
) -> None:
|
||||
"""Check if the schema needs to be upgraded."""
|
||||
assert instance.get_session is not None
|
||||
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
|
||||
for version in range(current_version, SCHEMA_VERSION):
|
||||
new_version = version + 1
|
||||
_LOGGER.info("Upgrading recorder db schema to version %s", new_version)
|
||||
_apply_update(instance, new_version, current_version)
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
_apply_update(hass, engine, session_maker, new_version, current_version)
|
||||
with session_scope(session=session_maker()) as session:
|
||||
session.add(SchemaChanges(schema_version=new_version))
|
||||
|
||||
_LOGGER.info("Upgrade to version %s done", new_version)
|
||||
|
||||
|
||||
def _create_index(instance, table_name, index_name):
|
||||
def _create_index(
|
||||
session_maker: Callable[[], Session], table_name: str, index_name: str
|
||||
) -> None:
|
||||
"""Create an index for the specified table.
|
||||
|
||||
The index name should match the name given for the index
|
||||
|
@ -104,7 +114,7 @@ def _create_index(instance, table_name, index_name):
|
|||
"be patient!",
|
||||
index_name,
|
||||
)
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
index.create(connection)
|
||||
|
@ -117,7 +127,9 @@ def _create_index(instance, table_name, index_name):
|
|||
_LOGGER.debug("Finished creating %s", index_name)
|
||||
|
||||
|
||||
def _drop_index(instance, table_name, index_name):
|
||||
def _drop_index(
|
||||
session_maker: Callable[[], Session], table_name: str, index_name: str
|
||||
) -> None:
|
||||
"""Drop an index from a specified table.
|
||||
|
||||
There is no universal way to do something like `DROP INDEX IF EXISTS`
|
||||
|
@ -132,7 +144,7 @@ def _drop_index(instance, table_name, index_name):
|
|||
success = False
|
||||
|
||||
# Engines like DB2/Oracle
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(text(f"DROP INDEX {index_name}"))
|
||||
|
@ -143,7 +155,7 @@ def _drop_index(instance, table_name, index_name):
|
|||
|
||||
# Engines like SQLite, SQL Server
|
||||
if not success:
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -160,7 +172,7 @@ def _drop_index(instance, table_name, index_name):
|
|||
|
||||
if not success:
|
||||
# Engines like MySQL, MS Access
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -194,7 +206,9 @@ def _drop_index(instance, table_name, index_name):
|
|||
)
|
||||
|
||||
|
||||
def _add_columns(instance, table_name, columns_def):
|
||||
def _add_columns(
|
||||
session_maker: Callable[[], Session], table_name: str, columns_def: list[str]
|
||||
) -> None:
|
||||
"""Add columns to a table."""
|
||||
_LOGGER.warning(
|
||||
"Adding columns %s to table %s. Note: this can take several "
|
||||
|
@ -206,7 +220,7 @@ def _add_columns(instance, table_name, columns_def):
|
|||
|
||||
columns_def = [f"ADD {col_def}" for col_def in columns_def]
|
||||
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -223,7 +237,7 @@ def _add_columns(instance, table_name, columns_def):
|
|||
_LOGGER.info("Unable to use quick column add. Adding 1 by 1")
|
||||
|
||||
for column_def in columns_def:
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -242,7 +256,12 @@ def _add_columns(instance, table_name, columns_def):
|
|||
)
|
||||
|
||||
|
||||
def _modify_columns(instance, engine, table_name, columns_def):
|
||||
def _modify_columns(
|
||||
session_maker: Callable[[], Session],
|
||||
engine: Engine,
|
||||
table_name: str,
|
||||
columns_def: list[str],
|
||||
) -> None:
|
||||
"""Modify columns in a table."""
|
||||
if engine.dialect.name == "sqlite":
|
||||
_LOGGER.debug(
|
||||
|
@ -274,7 +293,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
|
|||
else:
|
||||
columns_def = [f"MODIFY {col_def}" for col_def in columns_def]
|
||||
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -289,7 +308,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
|
|||
_LOGGER.info("Unable to use quick column modify. Modifying 1 by 1")
|
||||
|
||||
for column_def in columns_def:
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
|
@ -305,7 +324,9 @@ def _modify_columns(instance, engine, table_name, columns_def):
|
|||
)
|
||||
|
||||
|
||||
def _update_states_table_with_foreign_key_options(instance, engine):
|
||||
def _update_states_table_with_foreign_key_options(
|
||||
session_maker: Callable[[], Session], engine: Engine
|
||||
) -> None:
|
||||
"""Add the options to foreign key constraints."""
|
||||
inspector = sqlalchemy.inspect(engine)
|
||||
alters = []
|
||||
|
@ -333,7 +354,7 @@ def _update_states_table_with_foreign_key_options(instance, engine):
|
|||
)
|
||||
|
||||
for alter in alters:
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(DropConstraint(alter["old_fk"]))
|
||||
|
@ -346,7 +367,9 @@ def _update_states_table_with_foreign_key_options(instance, engine):
|
|||
)
|
||||
|
||||
|
||||
def _drop_foreign_key_constraints(instance, engine, table, columns):
|
||||
def _drop_foreign_key_constraints(
|
||||
session_maker: Callable[[], Session], engine: Engine, table: str, columns: list[str]
|
||||
) -> None:
|
||||
"""Drop foreign key constraints for a table on specific columns."""
|
||||
inspector = sqlalchemy.inspect(engine)
|
||||
drops = []
|
||||
|
@ -364,7 +387,7 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
|
|||
)
|
||||
|
||||
for drop in drops:
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
try:
|
||||
connection = session.connection()
|
||||
connection.execute(DropConstraint(drop))
|
||||
|
@ -376,19 +399,24 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
|
|||
)
|
||||
|
||||
|
||||
def _apply_update(instance, new_version, old_version): # noqa: C901
|
||||
def _apply_update( # noqa: C901
|
||||
hass: HomeAssistant,
|
||||
engine: Engine,
|
||||
session_maker: Callable[[], Session],
|
||||
new_version: int,
|
||||
old_version: int,
|
||||
) -> None:
|
||||
"""Perform operations to bring schema up to date."""
|
||||
engine = instance.engine
|
||||
dialect = engine.dialect.name
|
||||
big_int = "INTEGER(20)" if dialect == "mysql" else "INTEGER"
|
||||
|
||||
if new_version == 1:
|
||||
_create_index(instance, "events", "ix_events_time_fired")
|
||||
_create_index(session_maker, "events", "ix_events_time_fired")
|
||||
elif new_version == 2:
|
||||
# Create compound start/end index for recorder_runs
|
||||
_create_index(instance, "recorder_runs", "ix_recorder_runs_start_end")
|
||||
_create_index(session_maker, "recorder_runs", "ix_recorder_runs_start_end")
|
||||
# Create indexes for states
|
||||
_create_index(instance, "states", "ix_states_last_updated")
|
||||
_create_index(session_maker, "states", "ix_states_last_updated")
|
||||
elif new_version == 3:
|
||||
# There used to be a new index here, but it was removed in version 4.
|
||||
pass
|
||||
|
@ -398,41 +426,41 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
|
||||
if old_version == 3:
|
||||
# Remove index that was added in version 3
|
||||
_drop_index(instance, "states", "ix_states_created_domain")
|
||||
_drop_index(session_maker, "states", "ix_states_created_domain")
|
||||
if old_version == 2:
|
||||
# Remove index that was added in version 2
|
||||
_drop_index(instance, "states", "ix_states_entity_id_created")
|
||||
_drop_index(session_maker, "states", "ix_states_entity_id_created")
|
||||
|
||||
# Remove indexes that were added in version 0
|
||||
_drop_index(instance, "states", "states__state_changes")
|
||||
_drop_index(instance, "states", "states__significant_changes")
|
||||
_drop_index(instance, "states", "ix_states_entity_id_created")
|
||||
_drop_index(session_maker, "states", "states__state_changes")
|
||||
_drop_index(session_maker, "states", "states__significant_changes")
|
||||
_drop_index(session_maker, "states", "ix_states_entity_id_created")
|
||||
|
||||
_create_index(instance, "states", "ix_states_entity_id_last_updated")
|
||||
_create_index(session_maker, "states", "ix_states_entity_id_last_updated")
|
||||
elif new_version == 5:
|
||||
# Create supporting index for States.event_id foreign key
|
||||
_create_index(instance, "states", "ix_states_event_id")
|
||||
_create_index(session_maker, "states", "ix_states_event_id")
|
||||
elif new_version == 6:
|
||||
_add_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
"events",
|
||||
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
|
||||
)
|
||||
_create_index(instance, "events", "ix_events_context_id")
|
||||
_create_index(instance, "events", "ix_events_context_user_id")
|
||||
_create_index(session_maker, "events", "ix_events_context_id")
|
||||
_create_index(session_maker, "events", "ix_events_context_user_id")
|
||||
_add_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
"states",
|
||||
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
|
||||
)
|
||||
_create_index(instance, "states", "ix_states_context_id")
|
||||
_create_index(instance, "states", "ix_states_context_user_id")
|
||||
_create_index(session_maker, "states", "ix_states_context_id")
|
||||
_create_index(session_maker, "states", "ix_states_context_user_id")
|
||||
elif new_version == 7:
|
||||
_create_index(instance, "states", "ix_states_entity_id")
|
||||
_create_index(session_maker, "states", "ix_states_entity_id")
|
||||
elif new_version == 8:
|
||||
_add_columns(instance, "events", ["context_parent_id CHARACTER(36)"])
|
||||
_add_columns(instance, "states", ["old_state_id INTEGER"])
|
||||
_create_index(instance, "events", "ix_events_context_parent_id")
|
||||
_add_columns(session_maker, "events", ["context_parent_id CHARACTER(36)"])
|
||||
_add_columns(session_maker, "states", ["old_state_id INTEGER"])
|
||||
_create_index(session_maker, "events", "ix_events_context_parent_id")
|
||||
elif new_version == 9:
|
||||
# We now get the context from events with a join
|
||||
# since its always there on state_changed events
|
||||
|
@ -443,35 +471,35 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
# sqlalchemy alembic to make that work
|
||||
#
|
||||
# no longer dropping ix_states_context_id since its recreated in 28
|
||||
_drop_index(instance, "states", "ix_states_context_user_id")
|
||||
_drop_index(session_maker, "states", "ix_states_context_user_id")
|
||||
# This index won't be there if they were not running
|
||||
# nightly but we don't treat that as a critical issue
|
||||
_drop_index(instance, "states", "ix_states_context_parent_id")
|
||||
_drop_index(session_maker, "states", "ix_states_context_parent_id")
|
||||
# Redundant keys on composite index:
|
||||
# We already have ix_states_entity_id_last_updated
|
||||
_drop_index(instance, "states", "ix_states_entity_id")
|
||||
_create_index(instance, "events", "ix_events_event_type_time_fired")
|
||||
_drop_index(instance, "events", "ix_events_event_type")
|
||||
_drop_index(session_maker, "states", "ix_states_entity_id")
|
||||
_create_index(session_maker, "events", "ix_events_event_type_time_fired")
|
||||
_drop_index(session_maker, "events", "ix_events_event_type")
|
||||
elif new_version == 10:
|
||||
# Now done in step 11
|
||||
pass
|
||||
elif new_version == 11:
|
||||
_create_index(instance, "states", "ix_states_old_state_id")
|
||||
_update_states_table_with_foreign_key_options(instance, engine)
|
||||
_create_index(session_maker, "states", "ix_states_old_state_id")
|
||||
_update_states_table_with_foreign_key_options(session_maker, engine)
|
||||
elif new_version == 12:
|
||||
if engine.dialect.name == "mysql":
|
||||
_modify_columns(instance, engine, "events", ["event_data LONGTEXT"])
|
||||
_modify_columns(instance, engine, "states", ["attributes LONGTEXT"])
|
||||
_modify_columns(session_maker, engine, "events", ["event_data LONGTEXT"])
|
||||
_modify_columns(session_maker, engine, "states", ["attributes LONGTEXT"])
|
||||
elif new_version == 13:
|
||||
if engine.dialect.name == "mysql":
|
||||
_modify_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
engine,
|
||||
"events",
|
||||
["time_fired DATETIME(6)", "created DATETIME(6)"],
|
||||
)
|
||||
_modify_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
engine,
|
||||
"states",
|
||||
[
|
||||
|
@ -481,12 +509,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
],
|
||||
)
|
||||
elif new_version == 14:
|
||||
_modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
|
||||
_modify_columns(session_maker, engine, "events", ["event_type VARCHAR(64)"])
|
||||
elif new_version == 15:
|
||||
# This dropped the statistics table, done again in version 18.
|
||||
pass
|
||||
elif new_version == 16:
|
||||
_drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"])
|
||||
_drop_foreign_key_constraints(
|
||||
session_maker, engine, TABLE_STATES, ["old_state_id"]
|
||||
)
|
||||
elif new_version == 17:
|
||||
# This dropped the statistics table, done again in version 18.
|
||||
pass
|
||||
|
@ -511,13 +541,13 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
elif new_version == 19:
|
||||
# This adds the statistic runs table, insert a fake run to prevent duplicating
|
||||
# statistics.
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
session.add(StatisticsRuns(start=get_start_time()))
|
||||
elif new_version == 20:
|
||||
# This changed the precision of statistics from float to double
|
||||
if engine.dialect.name in ["mysql", "postgresql"]:
|
||||
_modify_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
engine,
|
||||
"statistics",
|
||||
[
|
||||
|
@ -539,7 +569,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
table,
|
||||
)
|
||||
with contextlib.suppress(SQLAlchemyError):
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
connection = session.connection()
|
||||
connection.execute(
|
||||
# Using LOCK=EXCLUSIVE to prevent the database from corrupting
|
||||
|
@ -574,7 +604,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
# Block 5-minute statistics for one hour from the last run, or it will overlap
|
||||
# with existing hourly statistics. Don't block on a database with no existing
|
||||
# statistics.
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
if session.query(Statistics.id).count() and (
|
||||
last_run_string := session.query(
|
||||
func.max(StatisticsRuns.start)
|
||||
|
@ -590,7 +620,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
# When querying the database, be careful to only explicitly query for columns
|
||||
# which were present in schema version 21. If querying the table, SQLAlchemy
|
||||
# will refer to future columns.
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
for sum_statistic in session.query(StatisticsMeta.id).filter_by(
|
||||
has_sum=true()
|
||||
):
|
||||
|
@ -617,48 +647,52 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
)
|
||||
elif new_version == 23:
|
||||
# Add name column to StatisticsMeta
|
||||
_add_columns(instance, "statistics_meta", ["name VARCHAR(255)"])
|
||||
_add_columns(session_maker, "statistics_meta", ["name VARCHAR(255)"])
|
||||
elif new_version == 24:
|
||||
# Recreate statistics indices to block duplicated statistics
|
||||
_drop_index(instance, "statistics", "ix_statistics_statistic_id_start")
|
||||
_drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start")
|
||||
_drop_index(
|
||||
instance,
|
||||
session_maker,
|
||||
"statistics_short_term",
|
||||
"ix_statistics_short_term_statistic_id_start",
|
||||
)
|
||||
try:
|
||||
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
|
||||
_create_index(
|
||||
instance,
|
||||
session_maker, "statistics", "ix_statistics_statistic_id_start"
|
||||
)
|
||||
_create_index(
|
||||
session_maker,
|
||||
"statistics_short_term",
|
||||
"ix_statistics_short_term_statistic_id_start",
|
||||
)
|
||||
except DatabaseError:
|
||||
# There may be duplicated statistics entries, delete duplicated statistics
|
||||
# and try again
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
delete_duplicates(instance, session)
|
||||
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
|
||||
with session_scope(session=session_maker()) as session:
|
||||
delete_duplicates(hass, session)
|
||||
_create_index(
|
||||
instance,
|
||||
session_maker, "statistics", "ix_statistics_statistic_id_start"
|
||||
)
|
||||
_create_index(
|
||||
session_maker,
|
||||
"statistics_short_term",
|
||||
"ix_statistics_short_term_statistic_id_start",
|
||||
)
|
||||
elif new_version == 25:
|
||||
_add_columns(instance, "states", [f"attributes_id {big_int}"])
|
||||
_create_index(instance, "states", "ix_states_attributes_id")
|
||||
_add_columns(session_maker, "states", [f"attributes_id {big_int}"])
|
||||
_create_index(session_maker, "states", "ix_states_attributes_id")
|
||||
elif new_version == 26:
|
||||
_create_index(instance, "statistics_runs", "ix_statistics_runs_start")
|
||||
_create_index(session_maker, "statistics_runs", "ix_statistics_runs_start")
|
||||
elif new_version == 27:
|
||||
_add_columns(instance, "events", [f"data_id {big_int}"])
|
||||
_create_index(instance, "events", "ix_events_data_id")
|
||||
_add_columns(session_maker, "events", [f"data_id {big_int}"])
|
||||
_create_index(session_maker, "events", "ix_events_data_id")
|
||||
elif new_version == 28:
|
||||
_add_columns(instance, "events", ["origin_idx INTEGER"])
|
||||
_add_columns(session_maker, "events", ["origin_idx INTEGER"])
|
||||
# We never use the user_id or parent_id index
|
||||
_drop_index(instance, "events", "ix_events_context_user_id")
|
||||
_drop_index(instance, "events", "ix_events_context_parent_id")
|
||||
_drop_index(session_maker, "events", "ix_events_context_user_id")
|
||||
_drop_index(session_maker, "events", "ix_events_context_parent_id")
|
||||
_add_columns(
|
||||
instance,
|
||||
session_maker,
|
||||
"states",
|
||||
[
|
||||
"origin_idx INTEGER",
|
||||
|
@ -667,14 +701,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
|
|||
"context_parent_id VARCHAR(36)",
|
||||
],
|
||||
)
|
||||
_create_index(instance, "states", "ix_states_context_id")
|
||||
_create_index(session_maker, "states", "ix_states_context_id")
|
||||
# Once there are no longer any state_changed events
|
||||
# in the events table we can drop the index on states.event_id
|
||||
else:
|
||||
raise ValueError(f"No schema migration defined for version {new_version}")
|
||||
|
||||
|
||||
def _inspect_schema_version(session):
|
||||
def _inspect_schema_version(session: Session) -> int:
|
||||
"""Determine the schema version by inspecting the db structure.
|
||||
|
||||
When the schema version is not present in the db, either db was just
|
||||
|
@ -696,4 +730,4 @@ def _inspect_schema_version(session):
|
|||
# Version 1 schema changes not found, this db needs to be migrated.
|
||||
current_version = SchemaChanges(schema_version=0)
|
||||
session.add(current_version)
|
||||
return current_version.schema_version
|
||||
return cast(int, current_version.schema_version)
|
||||
|
|
|
@ -47,7 +47,7 @@ def purge_old_data(
|
|||
)
|
||||
using_sqlite = instance.using_sqlite()
|
||||
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
# Purge a max of MAX_ROWS_TO_PURGE, based on the oldest states or events record
|
||||
(
|
||||
event_ids,
|
||||
|
@ -515,7 +515,7 @@ def _purge_filtered_events(
|
|||
def purge_entity_data(instance: Recorder, entity_filter: Callable[[str], bool]) -> bool:
|
||||
"""Purge states and events of specified entities."""
|
||||
using_sqlite = instance.using_sqlite()
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
selected_entity_ids: list[str] = [
|
||||
entity_id
|
||||
for (entity_id,) in session.query(distinct(States.entity_id)).all()
|
||||
|
|
|
@ -377,7 +377,7 @@ def _delete_duplicates_from_table(
|
|||
return (total_deleted_rows, all_non_identical_duplicates)
|
||||
|
||||
|
||||
def delete_duplicates(instance: Recorder, session: Session) -> None:
|
||||
def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
|
||||
"""Identify and delete duplicated statistics.
|
||||
|
||||
A backup will be made of duplicated statistics before it is deleted.
|
||||
|
@ -391,7 +391,7 @@ def delete_duplicates(instance: Recorder, session: Session) -> None:
|
|||
if non_identical_duplicates:
|
||||
isotime = dt_util.utcnow().isoformat()
|
||||
backup_file_name = f"deleted_statistics.{isotime}.json"
|
||||
backup_path = instance.hass.config.path(STORAGE_DIR, backup_file_name)
|
||||
backup_path = hass.config.path(STORAGE_DIR, backup_file_name)
|
||||
|
||||
os.makedirs(os.path.dirname(backup_path), exist_ok=True)
|
||||
with open(backup_path, "w", encoding="utf8") as backup_file:
|
||||
|
@ -551,7 +551,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
|
|||
end = start + timedelta(minutes=5)
|
||||
|
||||
# Return if we already have 5-minute statistics for the requested period
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
if session.query(StatisticsRuns).filter_by(start=start).first():
|
||||
_LOGGER.debug("Statistics already compiled for %s-%s", start, end)
|
||||
return True
|
||||
|
@ -578,7 +578,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
|
|||
|
||||
# Insert collected statistics in the database
|
||||
with session_scope(
|
||||
session=instance.get_session(), # type: ignore[misc]
|
||||
session=instance.get_session(),
|
||||
exception_filter=_filter_unique_constraint_integrity_error(instance),
|
||||
) as session:
|
||||
for stats in platform_stats:
|
||||
|
@ -768,7 +768,7 @@ def _configured_unit(unit: str | None, units: UnitSystem) -> str | None:
|
|||
|
||||
def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
|
||||
"""Clear statistics for a list of statistic_ids."""
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
session.query(StatisticsMeta).filter(
|
||||
StatisticsMeta.statistic_id.in_(statistic_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
@ -778,7 +778,7 @@ def update_statistics_metadata(
|
|||
instance: Recorder, statistic_id: str, unit_of_measurement: str | None
|
||||
) -> None:
|
||||
"""Update statistics metadata for a statistic_id."""
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
session.query(StatisticsMeta).filter(
|
||||
StatisticsMeta.statistic_id == statistic_id
|
||||
).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
|
||||
|
@ -1376,7 +1376,7 @@ def add_external_statistics(
|
|||
"""Process an add_external_statistics job."""
|
||||
|
||||
with session_scope(
|
||||
session=instance.get_session(), # type: ignore[misc]
|
||||
session=instance.get_session(),
|
||||
exception_filter=_filter_unique_constraint_integrity_error(instance),
|
||||
) as session:
|
||||
old_metadata_dict = get_metadata_with_session(
|
||||
|
@ -1403,7 +1403,7 @@ def adjust_statistics(
|
|||
) -> bool:
|
||||
"""Process an add_statistics job."""
|
||||
|
||||
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
|
||||
with session_scope(session=instance.get_session()) as session:
|
||||
metadata = get_metadata_with_session(
|
||||
instance.hass, session, statistic_ids=(statistic_id,)
|
||||
)
|
||||
|
|
|
@ -65,8 +65,6 @@ class PurgeTask(RecorderTask):
|
|||
|
||||
def run(self, instance: Recorder) -> None:
|
||||
"""Purge the database."""
|
||||
assert instance.get_session is not None
|
||||
|
||||
if purge.purge_old_data(
|
||||
instance, self.purge_before, self.repack, self.apply_filter
|
||||
):
|
||||
|
|
178
mypy.ini
178
mypy.ini
|
@ -1710,183 +1710,7 @@ no_implicit_optional = true
|
|||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.const]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.core]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.backup]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.executor]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.history]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.models]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.pool]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.purge]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.repack]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.run_history]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.services]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.statistics]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.system_health]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.tasks]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.util]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.websocket_api]
|
||||
[mypy-homeassistant.components.recorder.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
|
|
|
@ -138,6 +138,8 @@ async def test_shutdown_closes_connections(hass, recorder_mock):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(pool.shutdown.mock_calls) == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
assert instance.get_session()
|
||||
|
||||
|
||||
async def test_state_gets_saved_when_set_before_start_event(
|
||||
|
|
|
@ -60,9 +60,12 @@ async def test_schema_update_calls(hass):
|
|||
await async_wait_recording_done(hass)
|
||||
|
||||
assert recorder.util.async_migration_in_progress(hass) is False
|
||||
instance = recorder.get_instance(hass)
|
||||
engine = instance.engine
|
||||
session_maker = instance.get_session
|
||||
update.assert_has_calls(
|
||||
[
|
||||
call(hass.data[DATA_INSTANCE], version + 1, 0)
|
||||
call(hass, engine, session_maker, version + 1, 0)
|
||||
for version in range(0, models.SCHEMA_VERSION)
|
||||
]
|
||||
)
|
||||
|
@ -327,10 +330,10 @@ async def test_schema_migrate(hass, start_version):
|
|||
assert recorder.util.async_migration_in_progress(hass) is not True
|
||||
|
||||
|
||||
def test_invalid_update():
|
||||
def test_invalid_update(hass):
|
||||
"""Test that an invalid new version raises an exception."""
|
||||
with pytest.raises(ValueError):
|
||||
migration._apply_update(Mock(), -1, 0)
|
||||
migration._apply_update(hass, Mock(), Mock(), -1, 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -351,7 +354,9 @@ def test_modify_column(engine_type, substr):
|
|||
instance.get_session = Mock(return_value=session)
|
||||
engine = Mock()
|
||||
engine.dialect.name = engine_type
|
||||
migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
|
||||
migration._modify_columns(
|
||||
instance.get_session, engine, "events", ["event_type VARCHAR(64)"]
|
||||
)
|
||||
if substr:
|
||||
assert substr in connection.execute.call_args[0][0].text
|
||||
else:
|
||||
|
@ -365,8 +370,12 @@ def test_forgiving_add_column():
|
|||
session.execute(text("CREATE TABLE hello (id int)"))
|
||||
instance = Mock()
|
||||
instance.get_session = Mock(return_value=session)
|
||||
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
|
||||
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
|
||||
migration._add_columns(
|
||||
instance.get_session, "hello", ["context_id CHARACTER(36)"]
|
||||
)
|
||||
migration._add_columns(
|
||||
instance.get_session, "hello", ["context_id CHARACTER(36)"]
|
||||
)
|
||||
|
||||
|
||||
def test_forgiving_add_index():
|
||||
|
@ -376,7 +385,7 @@ def test_forgiving_add_index():
|
|||
with Session(engine) as session:
|
||||
instance = Mock()
|
||||
instance.get_session = Mock(return_value=session)
|
||||
migration._create_index(instance, "states", "ix_states_context_id")
|
||||
migration._create_index(instance.get_session, "states", "ix_states_context_id")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -740,7 +740,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog):
|
|||
hass = hass_recorder()
|
||||
wait_recording_done(hass)
|
||||
with session_scope(hass=hass) as session:
|
||||
delete_duplicates(hass.data[DATA_INSTANCE], session)
|
||||
delete_duplicates(hass, session)
|
||||
assert "duplicated statistics rows" not in caplog.text
|
||||
assert "Found non identical" not in caplog.text
|
||||
assert "Found duplicated" not in caplog.text
|
||||
|
|
Loading…
Add table
Reference in a new issue