diff --git a/.strict-typing b/.strict-typing index d0d81ddc5c0..866e8bfd95c 100644 --- a/.strict-typing +++ b/.strict-typing @@ -168,6 +168,7 @@ homeassistant.components.recorder.history homeassistant.components.recorder.purge homeassistant.components.recorder.repack homeassistant.components.recorder.statistics +homeassistant.components.recorder.util homeassistant.components.remote.* homeassistant.components.renault.* homeassistant.components.ridwell.* diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 1f9d8bfaa26..487b8dd22f7 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections.abc import Callable, Generator from contextlib import contextmanager -from datetime import timedelta +from datetime import datetime, timedelta import functools import logging import os import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from awesomeversion import ( AwesomeVersion, @@ -16,6 +16,7 @@ from awesomeversion import ( AwesomeVersionStrategy, ) from sqlalchemy import text +from sqlalchemy.engine.cursor import CursorFetchStrategy from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -95,7 +96,7 @@ def session_scope( session.close() -def commit(session, work): +def commit(session: Session, work: Any) -> bool: """Commit & retry work: Either a model or in a function.""" for _ in range(0, RETRIES): try: @@ -175,12 +176,12 @@ def validate_or_move_away_sqlite_database(dburl: str) -> bool: return True -def dburl_to_path(dburl): +def dburl_to_path(dburl: str) -> str: """Convert the db url into a filesystem path.""" return dburl[len(SQLITE_URL_PREFIX) :] -def last_run_was_recently_clean(cursor): +def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool: """Verify the last recorder run was recently clean.""" cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;") @@ -190,6 +191,7 @@ def last_run_was_recently_clean(cursor): return False last_run_end_time = process_timestamp(dt_util.parse_datetime(end_time[0])) + assert last_run_end_time is not None now = dt_util.utcnow() _LOGGER.debug("The last run ended at: %s (now: %s)", last_run_end_time, now) @@ -200,7 +202,7 @@ def last_run_was_recently_clean(cursor): return True -def basic_sanity_check(cursor): +def basic_sanity_check(cursor: CursorFetchStrategy) -> bool: """Check tables to make sure select does not fail.""" for table in ALL_TABLES: @@ -235,7 +237,7 @@ def validate_sqlite_database(dbpath: str) -> bool: return True -def run_checks_on_open_db(dbpath, cursor): +def run_checks_on_open_db(dbpath: str, cursor: CursorFetchStrategy) -> None: """Run checks that will generate a sqlite3 exception if there is corruption.""" sanity_check_passed = basic_sanity_check(cursor) last_run_was_clean = last_run_was_recently_clean(cursor) @@ -278,14 +280,14 @@ def move_away_broken_database(dbfile: str) -> None: os.rename(path, f"{path}{corrupt_postfix}") -def execute_on_connection(dbapi_connection, statement): +def execute_on_connection(dbapi_connection: Any, statement: str) -> None: """Execute a single statement with a dbapi connection.""" cursor = dbapi_connection.cursor() cursor.execute(statement) cursor.close() -def query_on_connection(dbapi_connection, statement): +def query_on_connection(dbapi_connection: Any, statement: str) -> Any: """Execute a single statement with a dbapi connection and return the result.""" cursor = dbapi_connection.cursor() cursor.execute(statement) @@ -294,30 +296,34 @@ def query_on_connection(dbapi_connection, statement): return result -def _warn_unsupported_dialect(dialect): +def _warn_unsupported_dialect(dialect_name: str) -> None: """Warn about unsupported database version.""" _LOGGER.warning( "Database %s is not supported; Home Assistant supports %s. " "Starting with Home Assistant 2022.2 this will prevent the recorder from " "starting. Please migrate your database to a supported software before then", - dialect, + dialect_name, "MariaDB ≥ 10.3, MySQL ≥ 8.0, PostgreSQL ≥ 12, SQLite ≥ 3.31.0", ) -def _warn_unsupported_version(server_version, dialect, minimum_version): +def _warn_unsupported_version( + server_version: str, dialect_name: str, minimum_version: str +) -> None: """Warn about unsupported database version.""" _LOGGER.warning( "Version %s of %s is not supported; minimum supported version is %s. " "Starting with Home Assistant 2022.2 this will prevent the recorder from " "starting. Please upgrade your database software before then", server_version, - dialect, + dialect_name, minimum_version, ) -def _extract_version_from_server_response(server_response): +def _extract_version_from_server_response( + server_response: str, +) -> AwesomeVersion | None: """Attempt to extract version from server response.""" try: return AwesomeVersion( @@ -330,8 +336,11 @@ def _extract_version_from_server_response(server_response): def setup_connection_for_dialect( - instance, dialect_name, dbapi_connection, first_connection -): + instance: Recorder, + dialect_name: str, + dbapi_connection: Any, + first_connection: bool, +) -> None: """Execute statements needed for dialect connection.""" # Returns False if the the connection needs to be setup # on the next connection, returns True if the connection @@ -406,7 +415,7 @@ def setup_connection_for_dialect( _warn_unsupported_dialect(dialect_name) -def end_incomplete_runs(session, start_time): +def end_incomplete_runs(session: Session, start_time: datetime) -> None: """End any incomplete recorder runs.""" for run in session.query(RecorderRuns).filter_by(end=None): run.closed_incorrect = True @@ -423,9 +432,9 @@ def retryable_database_job(description: str) -> Callable: The job should return True if it finished, and False if it needs to be rescheduled. """ - def decorator(job: Callable) -> Callable: + def decorator(job: Callable[[Any], bool]) -> Callable: @functools.wraps(job) - def wrapper(instance: Recorder, *args, **kwargs): + def wrapper(instance: Recorder, *args: Any, **kwargs: Any) -> bool: try: return job(instance, *args, **kwargs) except OperationalError as err: @@ -451,7 +460,7 @@ def retryable_database_job(description: str) -> Callable: return decorator -def perodic_db_cleanups(instance: Recorder): +def perodic_db_cleanups(instance: Recorder) -> None: """Run any database cleanups that need to happen perodiclly. These cleanups will happen nightly or after any purge. @@ -465,7 +474,7 @@ def perodic_db_cleanups(instance: Recorder): @contextmanager -def write_lock_db_sqlite(instance: Recorder): +def write_lock_db_sqlite(instance: Recorder) -> Generator[None, None, None]: """Lock database for writes.""" assert instance.engine is not None with instance.engine.connect() as connection: @@ -490,4 +499,5 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool: """ if DATA_INSTANCE not in hass.data: return False - return hass.data[DATA_INSTANCE].migration_in_progress + instance: Recorder = hass.data[DATA_INSTANCE] + return instance.migration_in_progress diff --git a/mypy.ini b/mypy.ini index 344cb82b372..98a765d407c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1650,6 +1650,17 @@ no_implicit_optional = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.recorder.util] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.remote.*] check_untyped_defs = true disallow_incomplete_defs = true