Complete strict typing for recorder (#71274)

* Complete strict typing for recorder

* update tests

* Update tests/components/recorder/test_migrate.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Update tests/components/recorder/test_migrate.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove the asserts

* remove ignore comments

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
J. Nick Koston 2022-05-04 12:22:50 -05:00 committed by GitHub
parent 13ce0a7d6a
commit eb77f8db85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 309 deletions

View file

@ -177,23 +177,7 @@ homeassistant.components.pure_energie.*
homeassistant.components.rainmachine.* homeassistant.components.rainmachine.*
homeassistant.components.rdw.* homeassistant.components.rdw.*
homeassistant.components.recollect_waste.* homeassistant.components.recollect_waste.*
homeassistant.components.recorder homeassistant.components.recorder.*
homeassistant.components.recorder.const
homeassistant.components.recorder.core
homeassistant.components.recorder.backup
homeassistant.components.recorder.executor
homeassistant.components.recorder.history
homeassistant.components.recorder.models
homeassistant.components.recorder.pool
homeassistant.components.recorder.purge
homeassistant.components.recorder.repack
homeassistant.components.recorder.run_history
homeassistant.components.recorder.services
homeassistant.components.recorder.statistics
homeassistant.components.recorder.system_health
homeassistant.components.recorder.tasks
homeassistant.components.recorder.util
homeassistant.components.recorder.websocket_api
homeassistant.components.remote.* homeassistant.components.remote.*
homeassistant.components.renault.* homeassistant.components.renault.*
homeassistant.components.ridwell.* homeassistant.components.ridwell.*

View file

@ -171,7 +171,7 @@ class Recorder(threading.Thread):
self._pending_event_data: dict[str, EventData] = {} self._pending_event_data: dict[str, EventData] = {}
self._pending_expunge: list[States] = [] self._pending_expunge: list[States] = []
self.event_session: Session | None = None self.event_session: Session | None = None
self.get_session: Callable[[], Session] | None = None self._get_session: Callable[[], Session] | None = None
self._completed_first_database_setup: bool | None = None self._completed_first_database_setup: bool | None = None
self.async_migration_event = asyncio.Event() self.async_migration_event = asyncio.Event()
self.migration_in_progress = False self.migration_in_progress = False
@ -205,6 +205,12 @@ class Recorder(threading.Thread):
"""Return if the recorder is recording.""" """Return if the recorder is recording."""
return self._event_listener is not None return self._event_listener is not None
def get_session(self) -> Session:
"""Get a new sqlalchemy session."""
if self._get_session is None:
raise RuntimeError("The database connection has not been established")
return self._get_session()
def queue_task(self, task: RecorderTask) -> None: def queue_task(self, task: RecorderTask) -> None:
"""Add a task to the recorder queue.""" """Add a task to the recorder queue."""
self._queue.put(task) self._queue.put(task)
@ -459,7 +465,7 @@ class Recorder(threading.Thread):
@callback @callback
def _async_setup_periodic_tasks(self) -> None: def _async_setup_periodic_tasks(self) -> None:
"""Prepare periodic tasks.""" """Prepare periodic tasks."""
if self.hass.is_stopping or not self.get_session: if self.hass.is_stopping or not self._get_session:
# Home Assistant is shutting down # Home Assistant is shutting down
return return
@ -591,7 +597,7 @@ class Recorder(threading.Thread):
while tries <= self.db_max_retries: while tries <= self.db_max_retries:
try: try:
self._setup_connection() self._setup_connection()
return migration.get_schema_version(self) return migration.get_schema_version(self.get_session)
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.exception( _LOGGER.exception(
"Error during connection setup: %s (retrying in %s seconds)", "Error during connection setup: %s (retrying in %s seconds)",
@ -619,7 +625,9 @@ class Recorder(threading.Thread):
self.hass.add_job(self._async_migration_started) self.hass.add_job(self._async_migration_started)
try: try:
migration.migrate_schema(self, current_version) migration.migrate_schema(
self.hass, self.engine, self.get_session, current_version
)
except exc.DatabaseError as err: except exc.DatabaseError as err:
if self._handle_database_error(err): if self._handle_database_error(err):
return True return True
@ -896,7 +904,6 @@ class Recorder(threading.Thread):
def _open_event_session(self) -> None: def _open_event_session(self) -> None:
"""Open the event session.""" """Open the event session."""
assert self.get_session is not None
self.event_session = self.get_session() self.event_session = self.get_session()
self.event_session.expire_on_commit = False self.event_session.expire_on_commit = False
@ -1011,7 +1018,7 @@ class Recorder(threading.Thread):
sqlalchemy_event.listen(self.engine, "connect", setup_recorder_connection) sqlalchemy_event.listen(self.engine, "connect", setup_recorder_connection)
Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)
self.get_session = scoped_session(sessionmaker(bind=self.engine, future=True)) self._get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
_LOGGER.debug("Connected to recorder database") _LOGGER.debug("Connected to recorder database")
def _close_connection(self) -> None: def _close_connection(self) -> None:
@ -1019,11 +1026,10 @@ class Recorder(threading.Thread):
assert self.engine is not None assert self.engine is not None
self.engine.dispose() self.engine.dispose()
self.engine = None self.engine = None
self.get_session = None self._get_session = None
def _setup_run(self) -> None: def _setup_run(self) -> None:
"""Log the start of the current run and schedule any needed jobs.""" """Log the start of the current run and schedule any needed jobs."""
assert self.get_session is not None
with session_scope(session=self.get_session()) as session: with session_scope(session=self.get_session()) as session:
end_incomplete_runs(session, self.run_history.recording_start) end_incomplete_runs(session, self.run_history.recording_start)
self.run_history.start(session) self.run_history.start(session)

View file

@ -1,11 +1,13 @@
"""Schema migration helpers.""" """Schema migration helpers."""
from collections.abc import Callable, Iterable
import contextlib import contextlib
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any from typing import cast
import sqlalchemy import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ( from sqlalchemy.exc import (
DatabaseError, DatabaseError,
InternalError, InternalError,
@ -13,9 +15,12 @@ from sqlalchemy.exc import (
ProgrammingError, ProgrammingError,
SQLAlchemyError, SQLAlchemyError,
) )
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import AddConstraint, DropConstraint from sqlalchemy.schema import AddConstraint, DropConstraint
from sqlalchemy.sql.expression import true from sqlalchemy.sql.expression import true
from homeassistant.core import HomeAssistant
from .models import ( from .models import (
SCHEMA_VERSION, SCHEMA_VERSION,
TABLE_STATES, TABLE_STATES,
@ -33,7 +38,7 @@ from .util import session_scope
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def raise_if_exception_missing_str(ex, match_substrs): def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) -> None:
"""Raise an exception if the exception and cause do not contain the match substrs.""" """Raise an exception if the exception and cause do not contain the match substrs."""
lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()] lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()]
for str_sub in match_substrs: for str_sub in match_substrs:
@ -44,10 +49,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
raise ex raise ex
def get_schema_version(instance: Any) -> int: def get_schema_version(session_maker: Callable[[], Session]) -> int:
"""Get the schema version.""" """Get the schema version."""
assert instance.get_session is not None with session_scope(session=session_maker()) as session:
with session_scope(session=instance.get_session()) as session:
res = ( res = (
session.query(SchemaChanges) session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc()) .order_by(SchemaChanges.change_id.desc())
@ -61,7 +65,7 @@ def get_schema_version(instance: Any) -> int:
"No schema version found. Inspected version: %s", current_version "No schema version found. Inspected version: %s", current_version
) )
return current_version return cast(int, current_version)
def schema_is_current(current_version: int) -> bool: def schema_is_current(current_version: int) -> bool:
@ -69,21 +73,27 @@ def schema_is_current(current_version: int) -> bool:
return current_version == SCHEMA_VERSION return current_version == SCHEMA_VERSION
def migrate_schema(instance: Any, current_version: int) -> None: def migrate_schema(
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
current_version: int,
) -> None:
"""Check if the schema needs to be upgraded.""" """Check if the schema needs to be upgraded."""
assert instance.get_session is not None
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version) _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
for version in range(current_version, SCHEMA_VERSION): for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1 new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version) _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(instance, new_version, current_version) _apply_update(hass, engine, session_maker, new_version, current_version)
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
session.add(SchemaChanges(schema_version=new_version)) session.add(SchemaChanges(schema_version=new_version))
_LOGGER.info("Upgrade to version %s done", new_version) _LOGGER.info("Upgrade to version %s done", new_version)
def _create_index(instance, table_name, index_name): def _create_index(
session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
"""Create an index for the specified table. """Create an index for the specified table.
The index name should match the name given for the index The index name should match the name given for the index
@ -104,7 +114,7 @@ def _create_index(instance, table_name, index_name):
"be patient!", "be patient!",
index_name, index_name,
) )
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
index.create(connection) index.create(connection)
@ -117,7 +127,9 @@ def _create_index(instance, table_name, index_name):
_LOGGER.debug("Finished creating %s", index_name) _LOGGER.debug("Finished creating %s", index_name)
def _drop_index(instance, table_name, index_name): def _drop_index(
session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
"""Drop an index from a specified table. """Drop an index from a specified table.
There is no universal way to do something like `DROP INDEX IF EXISTS` There is no universal way to do something like `DROP INDEX IF EXISTS`
@ -132,7 +144,7 @@ def _drop_index(instance, table_name, index_name):
success = False success = False
# Engines like DB2/Oracle # Engines like DB2/Oracle
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute(text(f"DROP INDEX {index_name}")) connection.execute(text(f"DROP INDEX {index_name}"))
@ -143,7 +155,7 @@ def _drop_index(instance, table_name, index_name):
# Engines like SQLite, SQL Server # Engines like SQLite, SQL Server
if not success: if not success:
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -160,7 +172,7 @@ def _drop_index(instance, table_name, index_name):
if not success: if not success:
# Engines like MySQL, MS Access # Engines like MySQL, MS Access
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -194,7 +206,9 @@ def _drop_index(instance, table_name, index_name):
) )
def _add_columns(instance, table_name, columns_def): def _add_columns(
session_maker: Callable[[], Session], table_name: str, columns_def: list[str]
) -> None:
"""Add columns to a table.""" """Add columns to a table."""
_LOGGER.warning( _LOGGER.warning(
"Adding columns %s to table %s. Note: this can take several " "Adding columns %s to table %s. Note: this can take several "
@ -206,7 +220,7 @@ def _add_columns(instance, table_name, columns_def):
columns_def = [f"ADD {col_def}" for col_def in columns_def] columns_def = [f"ADD {col_def}" for col_def in columns_def]
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -223,7 +237,7 @@ def _add_columns(instance, table_name, columns_def):
_LOGGER.info("Unable to use quick column add. Adding 1 by 1") _LOGGER.info("Unable to use quick column add. Adding 1 by 1")
for column_def in columns_def: for column_def in columns_def:
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -242,7 +256,12 @@ def _add_columns(instance, table_name, columns_def):
) )
def _modify_columns(instance, engine, table_name, columns_def): def _modify_columns(
session_maker: Callable[[], Session],
engine: Engine,
table_name: str,
columns_def: list[str],
) -> None:
"""Modify columns in a table.""" """Modify columns in a table."""
if engine.dialect.name == "sqlite": if engine.dialect.name == "sqlite":
_LOGGER.debug( _LOGGER.debug(
@ -274,7 +293,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
else: else:
columns_def = [f"MODIFY {col_def}" for col_def in columns_def] columns_def = [f"MODIFY {col_def}" for col_def in columns_def]
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -289,7 +308,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
_LOGGER.info("Unable to use quick column modify. Modifying 1 by 1") _LOGGER.info("Unable to use quick column modify. Modifying 1 by 1")
for column_def in columns_def: for column_def in columns_def:
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
@ -305,7 +324,9 @@ def _modify_columns(instance, engine, table_name, columns_def):
) )
def _update_states_table_with_foreign_key_options(instance, engine): def _update_states_table_with_foreign_key_options(
session_maker: Callable[[], Session], engine: Engine
) -> None:
"""Add the options to foreign key constraints.""" """Add the options to foreign key constraints."""
inspector = sqlalchemy.inspect(engine) inspector = sqlalchemy.inspect(engine)
alters = [] alters = []
@ -333,7 +354,7 @@ def _update_states_table_with_foreign_key_options(instance, engine):
) )
for alter in alters: for alter in alters:
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute(DropConstraint(alter["old_fk"])) connection.execute(DropConstraint(alter["old_fk"]))
@ -346,7 +367,9 @@ def _update_states_table_with_foreign_key_options(instance, engine):
) )
def _drop_foreign_key_constraints(instance, engine, table, columns): def _drop_foreign_key_constraints(
session_maker: Callable[[], Session], engine: Engine, table: str, columns: list[str]
) -> None:
"""Drop foreign key constraints for a table on specific columns.""" """Drop foreign key constraints for a table on specific columns."""
inspector = sqlalchemy.inspect(engine) inspector = sqlalchemy.inspect(engine)
drops = [] drops = []
@ -364,7 +387,7 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
) )
for drop in drops: for drop in drops:
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute(DropConstraint(drop)) connection.execute(DropConstraint(drop))
@ -376,19 +399,24 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
) )
def _apply_update(instance, new_version, old_version): # noqa: C901 def _apply_update( # noqa: C901
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
new_version: int,
old_version: int,
) -> None:
"""Perform operations to bring schema up to date.""" """Perform operations to bring schema up to date."""
engine = instance.engine
dialect = engine.dialect.name dialect = engine.dialect.name
big_int = "INTEGER(20)" if dialect == "mysql" else "INTEGER" big_int = "INTEGER(20)" if dialect == "mysql" else "INTEGER"
if new_version == 1: if new_version == 1:
_create_index(instance, "events", "ix_events_time_fired") _create_index(session_maker, "events", "ix_events_time_fired")
elif new_version == 2: elif new_version == 2:
# Create compound start/end index for recorder_runs # Create compound start/end index for recorder_runs
_create_index(instance, "recorder_runs", "ix_recorder_runs_start_end") _create_index(session_maker, "recorder_runs", "ix_recorder_runs_start_end")
# Create indexes for states # Create indexes for states
_create_index(instance, "states", "ix_states_last_updated") _create_index(session_maker, "states", "ix_states_last_updated")
elif new_version == 3: elif new_version == 3:
# There used to be a new index here, but it was removed in version 4. # There used to be a new index here, but it was removed in version 4.
pass pass
@ -398,41 +426,41 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
if old_version == 3: if old_version == 3:
# Remove index that was added in version 3 # Remove index that was added in version 3
_drop_index(instance, "states", "ix_states_created_domain") _drop_index(session_maker, "states", "ix_states_created_domain")
if old_version == 2: if old_version == 2:
# Remove index that was added in version 2 # Remove index that was added in version 2
_drop_index(instance, "states", "ix_states_entity_id_created") _drop_index(session_maker, "states", "ix_states_entity_id_created")
# Remove indexes that were added in version 0 # Remove indexes that were added in version 0
_drop_index(instance, "states", "states__state_changes") _drop_index(session_maker, "states", "states__state_changes")
_drop_index(instance, "states", "states__significant_changes") _drop_index(session_maker, "states", "states__significant_changes")
_drop_index(instance, "states", "ix_states_entity_id_created") _drop_index(session_maker, "states", "ix_states_entity_id_created")
_create_index(instance, "states", "ix_states_entity_id_last_updated") _create_index(session_maker, "states", "ix_states_entity_id_last_updated")
elif new_version == 5: elif new_version == 5:
# Create supporting index for States.event_id foreign key # Create supporting index for States.event_id foreign key
_create_index(instance, "states", "ix_states_event_id") _create_index(session_maker, "states", "ix_states_event_id")
elif new_version == 6: elif new_version == 6:
_add_columns( _add_columns(
instance, session_maker,
"events", "events",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
) )
_create_index(instance, "events", "ix_events_context_id") _create_index(session_maker, "events", "ix_events_context_id")
_create_index(instance, "events", "ix_events_context_user_id") _create_index(session_maker, "events", "ix_events_context_user_id")
_add_columns( _add_columns(
instance, session_maker,
"states", "states",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
) )
_create_index(instance, "states", "ix_states_context_id") _create_index(session_maker, "states", "ix_states_context_id")
_create_index(instance, "states", "ix_states_context_user_id") _create_index(session_maker, "states", "ix_states_context_user_id")
elif new_version == 7: elif new_version == 7:
_create_index(instance, "states", "ix_states_entity_id") _create_index(session_maker, "states", "ix_states_entity_id")
elif new_version == 8: elif new_version == 8:
_add_columns(instance, "events", ["context_parent_id CHARACTER(36)"]) _add_columns(session_maker, "events", ["context_parent_id CHARACTER(36)"])
_add_columns(instance, "states", ["old_state_id INTEGER"]) _add_columns(session_maker, "states", ["old_state_id INTEGER"])
_create_index(instance, "events", "ix_events_context_parent_id") _create_index(session_maker, "events", "ix_events_context_parent_id")
elif new_version == 9: elif new_version == 9:
# We now get the context from events with a join # We now get the context from events with a join
# since its always there on state_changed events # since its always there on state_changed events
@ -443,35 +471,35 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# sqlalchemy alembic to make that work # sqlalchemy alembic to make that work
# #
# no longer dropping ix_states_context_id since its recreated in 28 # no longer dropping ix_states_context_id since its recreated in 28
_drop_index(instance, "states", "ix_states_context_user_id") _drop_index(session_maker, "states", "ix_states_context_user_id")
# This index won't be there if they were not running # This index won't be there if they were not running
# nightly but we don't treat that as a critical issue # nightly but we don't treat that as a critical issue
_drop_index(instance, "states", "ix_states_context_parent_id") _drop_index(session_maker, "states", "ix_states_context_parent_id")
# Redundant keys on composite index: # Redundant keys on composite index:
# We already have ix_states_entity_id_last_updated # We already have ix_states_entity_id_last_updated
_drop_index(instance, "states", "ix_states_entity_id") _drop_index(session_maker, "states", "ix_states_entity_id")
_create_index(instance, "events", "ix_events_event_type_time_fired") _create_index(session_maker, "events", "ix_events_event_type_time_fired")
_drop_index(instance, "events", "ix_events_event_type") _drop_index(session_maker, "events", "ix_events_event_type")
elif new_version == 10: elif new_version == 10:
# Now done in step 11 # Now done in step 11
pass pass
elif new_version == 11: elif new_version == 11:
_create_index(instance, "states", "ix_states_old_state_id") _create_index(session_maker, "states", "ix_states_old_state_id")
_update_states_table_with_foreign_key_options(instance, engine) _update_states_table_with_foreign_key_options(session_maker, engine)
elif new_version == 12: elif new_version == 12:
if engine.dialect.name == "mysql": if engine.dialect.name == "mysql":
_modify_columns(instance, engine, "events", ["event_data LONGTEXT"]) _modify_columns(session_maker, engine, "events", ["event_data LONGTEXT"])
_modify_columns(instance, engine, "states", ["attributes LONGTEXT"]) _modify_columns(session_maker, engine, "states", ["attributes LONGTEXT"])
elif new_version == 13: elif new_version == 13:
if engine.dialect.name == "mysql": if engine.dialect.name == "mysql":
_modify_columns( _modify_columns(
instance, session_maker,
engine, engine,
"events", "events",
["time_fired DATETIME(6)", "created DATETIME(6)"], ["time_fired DATETIME(6)", "created DATETIME(6)"],
) )
_modify_columns( _modify_columns(
instance, session_maker,
engine, engine,
"states", "states",
[ [
@ -481,12 +509,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
], ],
) )
elif new_version == 14: elif new_version == 14:
_modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"]) _modify_columns(session_maker, engine, "events", ["event_type VARCHAR(64)"])
elif new_version == 15: elif new_version == 15:
# This dropped the statistics table, done again in version 18. # This dropped the statistics table, done again in version 18.
pass pass
elif new_version == 16: elif new_version == 16:
_drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"]) _drop_foreign_key_constraints(
session_maker, engine, TABLE_STATES, ["old_state_id"]
)
elif new_version == 17: elif new_version == 17:
# This dropped the statistics table, done again in version 18. # This dropped the statistics table, done again in version 18.
pass pass
@ -511,13 +541,13 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
elif new_version == 19: elif new_version == 19:
# This adds the statistic runs table, insert a fake run to prevent duplicating # This adds the statistic runs table, insert a fake run to prevent duplicating
# statistics. # statistics.
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
session.add(StatisticsRuns(start=get_start_time())) session.add(StatisticsRuns(start=get_start_time()))
elif new_version == 20: elif new_version == 20:
# This changed the precision of statistics from float to double # This changed the precision of statistics from float to double
if engine.dialect.name in ["mysql", "postgresql"]: if engine.dialect.name in ["mysql", "postgresql"]:
_modify_columns( _modify_columns(
instance, session_maker,
engine, engine,
"statistics", "statistics",
[ [
@ -539,7 +569,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
table, table,
) )
with contextlib.suppress(SQLAlchemyError): with contextlib.suppress(SQLAlchemyError):
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
connection = session.connection() connection = session.connection()
connection.execute( connection.execute(
# Using LOCK=EXCLUSIVE to prevent the database from corrupting # Using LOCK=EXCLUSIVE to prevent the database from corrupting
@ -574,7 +604,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# Block 5-minute statistics for one hour from the last run, or it will overlap # Block 5-minute statistics for one hour from the last run, or it will overlap
# with existing hourly statistics. Don't block on a database with no existing # with existing hourly statistics. Don't block on a database with no existing
# statistics. # statistics.
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
if session.query(Statistics.id).count() and ( if session.query(Statistics.id).count() and (
last_run_string := session.query( last_run_string := session.query(
func.max(StatisticsRuns.start) func.max(StatisticsRuns.start)
@ -590,7 +620,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# When querying the database, be careful to only explicitly query for columns # When querying the database, be careful to only explicitly query for columns
# which were present in schema version 21. If querying the table, SQLAlchemy # which were present in schema version 21. If querying the table, SQLAlchemy
# will refer to future columns. # will refer to future columns.
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
for sum_statistic in session.query(StatisticsMeta.id).filter_by( for sum_statistic in session.query(StatisticsMeta.id).filter_by(
has_sum=true() has_sum=true()
): ):
@ -617,48 +647,52 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
) )
elif new_version == 23: elif new_version == 23:
# Add name column to StatisticsMeta # Add name column to StatisticsMeta
_add_columns(instance, "statistics_meta", ["name VARCHAR(255)"]) _add_columns(session_maker, "statistics_meta", ["name VARCHAR(255)"])
elif new_version == 24: elif new_version == 24:
# Recreate statistics indices to block duplicated statistics # Recreate statistics indices to block duplicated statistics
_drop_index(instance, "statistics", "ix_statistics_statistic_id_start") _drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start")
_drop_index( _drop_index(
instance, session_maker,
"statistics_short_term", "statistics_short_term",
"ix_statistics_short_term_statistic_id_start", "ix_statistics_short_term_statistic_id_start",
) )
try: try:
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
_create_index( _create_index(
instance, session_maker, "statistics", "ix_statistics_statistic_id_start"
)
_create_index(
session_maker,
"statistics_short_term", "statistics_short_term",
"ix_statistics_short_term_statistic_id_start", "ix_statistics_short_term_statistic_id_start",
) )
except DatabaseError: except DatabaseError:
# There may be duplicated statistics entries, delete duplicated statistics # There may be duplicated statistics entries, delete duplicated statistics
# and try again # and try again
with session_scope(session=instance.get_session()) as session: with session_scope(session=session_maker()) as session:
delete_duplicates(instance, session) delete_duplicates(hass, session)
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
_create_index( _create_index(
instance, session_maker, "statistics", "ix_statistics_statistic_id_start"
)
_create_index(
session_maker,
"statistics_short_term", "statistics_short_term",
"ix_statistics_short_term_statistic_id_start", "ix_statistics_short_term_statistic_id_start",
) )
elif new_version == 25: elif new_version == 25:
_add_columns(instance, "states", [f"attributes_id {big_int}"]) _add_columns(session_maker, "states", [f"attributes_id {big_int}"])
_create_index(instance, "states", "ix_states_attributes_id") _create_index(session_maker, "states", "ix_states_attributes_id")
elif new_version == 26: elif new_version == 26:
_create_index(instance, "statistics_runs", "ix_statistics_runs_start") _create_index(session_maker, "statistics_runs", "ix_statistics_runs_start")
elif new_version == 27: elif new_version == 27:
_add_columns(instance, "events", [f"data_id {big_int}"]) _add_columns(session_maker, "events", [f"data_id {big_int}"])
_create_index(instance, "events", "ix_events_data_id") _create_index(session_maker, "events", "ix_events_data_id")
elif new_version == 28: elif new_version == 28:
_add_columns(instance, "events", ["origin_idx INTEGER"]) _add_columns(session_maker, "events", ["origin_idx INTEGER"])
# We never use the user_id or parent_id index # We never use the user_id or parent_id index
_drop_index(instance, "events", "ix_events_context_user_id") _drop_index(session_maker, "events", "ix_events_context_user_id")
_drop_index(instance, "events", "ix_events_context_parent_id") _drop_index(session_maker, "events", "ix_events_context_parent_id")
_add_columns( _add_columns(
instance, session_maker,
"states", "states",
[ [
"origin_idx INTEGER", "origin_idx INTEGER",
@ -667,14 +701,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
"context_parent_id VARCHAR(36)", "context_parent_id VARCHAR(36)",
], ],
) )
_create_index(instance, "states", "ix_states_context_id") _create_index(session_maker, "states", "ix_states_context_id")
# Once there are no longer any state_changed events # Once there are no longer any state_changed events
# in the events table we can drop the index on states.event_id # in the events table we can drop the index on states.event_id
else: else:
raise ValueError(f"No schema migration defined for version {new_version}") raise ValueError(f"No schema migration defined for version {new_version}")
def _inspect_schema_version(session): def _inspect_schema_version(session: Session) -> int:
"""Determine the schema version by inspecting the db structure. """Determine the schema version by inspecting the db structure.
When the schema version is not present in the db, either db was just When the schema version is not present in the db, either db was just
@ -696,4 +730,4 @@ def _inspect_schema_version(session):
# Version 1 schema changes not found, this db needs to be migrated. # Version 1 schema changes not found, this db needs to be migrated.
current_version = SchemaChanges(schema_version=0) current_version = SchemaChanges(schema_version=0)
session.add(current_version) session.add(current_version)
return current_version.schema_version return cast(int, current_version.schema_version)

View file

@ -47,7 +47,7 @@ def purge_old_data(
) )
using_sqlite = instance.using_sqlite() using_sqlite = instance.using_sqlite()
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
# Purge a max of MAX_ROWS_TO_PURGE, based on the oldest states or events record # Purge a max of MAX_ROWS_TO_PURGE, based on the oldest states or events record
( (
event_ids, event_ids,
@ -515,7 +515,7 @@ def _purge_filtered_events(
def purge_entity_data(instance: Recorder, entity_filter: Callable[[str], bool]) -> bool: def purge_entity_data(instance: Recorder, entity_filter: Callable[[str], bool]) -> bool:
"""Purge states and events of specified entities.""" """Purge states and events of specified entities."""
using_sqlite = instance.using_sqlite() using_sqlite = instance.using_sqlite()
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
selected_entity_ids: list[str] = [ selected_entity_ids: list[str] = [
entity_id entity_id
for (entity_id,) in session.query(distinct(States.entity_id)).all() for (entity_id,) in session.query(distinct(States.entity_id)).all()

View file

@ -377,7 +377,7 @@ def _delete_duplicates_from_table(
return (total_deleted_rows, all_non_identical_duplicates) return (total_deleted_rows, all_non_identical_duplicates)
def delete_duplicates(instance: Recorder, session: Session) -> None: def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
"""Identify and delete duplicated statistics. """Identify and delete duplicated statistics.
A backup will be made of duplicated statistics before it is deleted. A backup will be made of duplicated statistics before it is deleted.
@ -391,7 +391,7 @@ def delete_duplicates(instance: Recorder, session: Session) -> None:
if non_identical_duplicates: if non_identical_duplicates:
isotime = dt_util.utcnow().isoformat() isotime = dt_util.utcnow().isoformat()
backup_file_name = f"deleted_statistics.{isotime}.json" backup_file_name = f"deleted_statistics.{isotime}.json"
backup_path = instance.hass.config.path(STORAGE_DIR, backup_file_name) backup_path = hass.config.path(STORAGE_DIR, backup_file_name)
os.makedirs(os.path.dirname(backup_path), exist_ok=True) os.makedirs(os.path.dirname(backup_path), exist_ok=True)
with open(backup_path, "w", encoding="utf8") as backup_file: with open(backup_path, "w", encoding="utf8") as backup_file:
@ -551,7 +551,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
end = start + timedelta(minutes=5) end = start + timedelta(minutes=5)
# Return if we already have 5-minute statistics for the requested period # Return if we already have 5-minute statistics for the requested period
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
if session.query(StatisticsRuns).filter_by(start=start).first(): if session.query(StatisticsRuns).filter_by(start=start).first():
_LOGGER.debug("Statistics already compiled for %s-%s", start, end) _LOGGER.debug("Statistics already compiled for %s-%s", start, end)
return True return True
@ -578,7 +578,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
# Insert collected statistics in the database # Insert collected statistics in the database
with session_scope( with session_scope(
session=instance.get_session(), # type: ignore[misc] session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance), exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session: ) as session:
for stats in platform_stats: for stats in platform_stats:
@ -768,7 +768,7 @@ def _configured_unit(unit: str | None, units: UnitSystem) -> str | None:
def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids.""" """Clear statistics for a list of statistic_ids."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter( session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids) StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
@ -778,7 +778,7 @@ def update_statistics_metadata(
instance: Recorder, statistic_id: str, unit_of_measurement: str | None instance: Recorder, statistic_id: str, unit_of_measurement: str | None
) -> None: ) -> None:
"""Update statistics metadata for a statistic_id.""" """Update statistics metadata for a statistic_id."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter( session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: unit_of_measurement}) ).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
@ -1376,7 +1376,7 @@ def add_external_statistics(
"""Process an add_external_statistics job.""" """Process an add_external_statistics job."""
with session_scope( with session_scope(
session=instance.get_session(), # type: ignore[misc] session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance), exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session: ) as session:
old_metadata_dict = get_metadata_with_session( old_metadata_dict = get_metadata_with_session(
@ -1403,7 +1403,7 @@ def adjust_statistics(
) -> bool: ) -> bool:
"""Process an add_statistics job.""" """Process an add_statistics job."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc] with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session( metadata = get_metadata_with_session(
instance.hass, session, statistic_ids=(statistic_id,) instance.hass, session, statistic_ids=(statistic_id,)
) )

View file

@ -65,8 +65,6 @@ class PurgeTask(RecorderTask):
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Purge the database.""" """Purge the database."""
assert instance.get_session is not None
if purge.purge_old_data( if purge.purge_old_data(
instance, self.purge_before, self.repack, self.apply_filter instance, self.purge_before, self.repack, self.apply_filter
): ):

178
mypy.ini
View file

@ -1710,183 +1710,7 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.recorder] [mypy-homeassistant.components.recorder.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.const]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.core]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.backup]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.executor]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.history]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.pool]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.purge]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.repack]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.run_history]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.services]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.statistics]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.system_health]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.tasks]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.util]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.websocket_api]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true
disallow_subclassing_any = true disallow_subclassing_any = true

View file

@ -138,6 +138,8 @@ async def test_shutdown_closes_connections(hass, recorder_mock):
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(pool.shutdown.mock_calls) == 1 assert len(pool.shutdown.mock_calls) == 1
with pytest.raises(RuntimeError):
assert instance.get_session()
async def test_state_gets_saved_when_set_before_start_event( async def test_state_gets_saved_when_set_before_start_event(

View file

@ -60,9 +60,12 @@ async def test_schema_update_calls(hass):
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
assert recorder.util.async_migration_in_progress(hass) is False assert recorder.util.async_migration_in_progress(hass) is False
instance = recorder.get_instance(hass)
engine = instance.engine
session_maker = instance.get_session
update.assert_has_calls( update.assert_has_calls(
[ [
call(hass.data[DATA_INSTANCE], version + 1, 0) call(hass, engine, session_maker, version + 1, 0)
for version in range(0, models.SCHEMA_VERSION) for version in range(0, models.SCHEMA_VERSION)
] ]
) )
@ -327,10 +330,10 @@ async def test_schema_migrate(hass, start_version):
assert recorder.util.async_migration_in_progress(hass) is not True assert recorder.util.async_migration_in_progress(hass) is not True
def test_invalid_update(): def test_invalid_update(hass):
"""Test that an invalid new version raises an exception.""" """Test that an invalid new version raises an exception."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
migration._apply_update(Mock(), -1, 0) migration._apply_update(hass, Mock(), Mock(), -1, 0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -351,7 +354,9 @@ def test_modify_column(engine_type, substr):
instance.get_session = Mock(return_value=session) instance.get_session = Mock(return_value=session)
engine = Mock() engine = Mock()
engine.dialect.name = engine_type engine.dialect.name = engine_type
migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"]) migration._modify_columns(
instance.get_session, engine, "events", ["event_type VARCHAR(64)"]
)
if substr: if substr:
assert substr in connection.execute.call_args[0][0].text assert substr in connection.execute.call_args[0][0].text
else: else:
@ -365,8 +370,12 @@ def test_forgiving_add_column():
session.execute(text("CREATE TABLE hello (id int)")) session.execute(text("CREATE TABLE hello (id int)"))
instance = Mock() instance = Mock()
instance.get_session = Mock(return_value=session) instance.get_session = Mock(return_value=session)
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"]) migration._add_columns(
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"]) instance.get_session, "hello", ["context_id CHARACTER(36)"]
)
migration._add_columns(
instance.get_session, "hello", ["context_id CHARACTER(36)"]
)
def test_forgiving_add_index(): def test_forgiving_add_index():
@ -376,7 +385,7 @@ def test_forgiving_add_index():
with Session(engine) as session: with Session(engine) as session:
instance = Mock() instance = Mock()
instance.get_session = Mock(return_value=session) instance.get_session = Mock(return_value=session)
migration._create_index(instance, "states", "ix_states_context_id") migration._create_index(instance.get_session, "states", "ix_states_context_id")
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -740,7 +740,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog):
hass = hass_recorder() hass = hass_recorder()
wait_recording_done(hass) wait_recording_done(hass)
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
delete_duplicates(hass.data[DATA_INSTANCE], session) delete_duplicates(hass, session)
assert "duplicated statistics rows" not in caplog.text assert "duplicated statistics rows" not in caplog.text
assert "Found non identical" not in caplog.text assert "Found non identical" not in caplog.text
assert "Found duplicated" not in caplog.text assert "Found duplicated" not in caplog.text