Add strict typing for recorder util (#68681)
This commit is contained in:
parent
4dc8aff3d5
commit
225f7a989b
3 changed files with 44 additions and 22 deletions
|
@ -168,6 +168,7 @@ homeassistant.components.recorder.history
|
||||||
homeassistant.components.recorder.purge
|
homeassistant.components.recorder.purge
|
||||||
homeassistant.components.recorder.repack
|
homeassistant.components.recorder.repack
|
||||||
homeassistant.components.recorder.statistics
|
homeassistant.components.recorder.statistics
|
||||||
|
homeassistant.components.recorder.util
|
||||||
homeassistant.components.remote.*
|
homeassistant.components.remote.*
|
||||||
homeassistant.components.renault.*
|
homeassistant.components.renault.*
|
||||||
homeassistant.components.ridwell.*
|
homeassistant.components.ridwell.*
|
||||||
|
|
|
@ -3,12 +3,12 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import timedelta
|
from datetime import datetime, timedelta
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from awesomeversion import (
|
from awesomeversion import (
|
||||||
AwesomeVersion,
|
AwesomeVersion,
|
||||||
|
@ -16,6 +16,7 @@ from awesomeversion import (
|
||||||
AwesomeVersionStrategy,
|
AwesomeVersionStrategy,
|
||||||
)
|
)
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine.cursor import CursorFetchStrategy
|
||||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
@ -95,7 +96,7 @@ def session_scope(
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
def commit(session, work):
|
def commit(session: Session, work: Any) -> bool:
|
||||||
"""Commit & retry work: Either a model or in a function."""
|
"""Commit & retry work: Either a model or in a function."""
|
||||||
for _ in range(0, RETRIES):
|
for _ in range(0, RETRIES):
|
||||||
try:
|
try:
|
||||||
|
@ -175,12 +176,12 @@ def validate_or_move_away_sqlite_database(dburl: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def dburl_to_path(dburl):
|
def dburl_to_path(dburl: str) -> str:
|
||||||
"""Convert the db url into a filesystem path."""
|
"""Convert the db url into a filesystem path."""
|
||||||
return dburl[len(SQLITE_URL_PREFIX) :]
|
return dburl[len(SQLITE_URL_PREFIX) :]
|
||||||
|
|
||||||
|
|
||||||
def last_run_was_recently_clean(cursor):
|
def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool:
|
||||||
"""Verify the last recorder run was recently clean."""
|
"""Verify the last recorder run was recently clean."""
|
||||||
|
|
||||||
cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
|
cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
|
||||||
|
@ -190,6 +191,7 @@ def last_run_was_recently_clean(cursor):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
last_run_end_time = process_timestamp(dt_util.parse_datetime(end_time[0]))
|
last_run_end_time = process_timestamp(dt_util.parse_datetime(end_time[0]))
|
||||||
|
assert last_run_end_time is not None
|
||||||
now = dt_util.utcnow()
|
now = dt_util.utcnow()
|
||||||
|
|
||||||
_LOGGER.debug("The last run ended at: %s (now: %s)", last_run_end_time, now)
|
_LOGGER.debug("The last run ended at: %s (now: %s)", last_run_end_time, now)
|
||||||
|
@ -200,7 +202,7 @@ def last_run_was_recently_clean(cursor):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def basic_sanity_check(cursor):
|
def basic_sanity_check(cursor: CursorFetchStrategy) -> bool:
|
||||||
"""Check tables to make sure select does not fail."""
|
"""Check tables to make sure select does not fail."""
|
||||||
|
|
||||||
for table in ALL_TABLES:
|
for table in ALL_TABLES:
|
||||||
|
@ -235,7 +237,7 @@ def validate_sqlite_database(dbpath: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def run_checks_on_open_db(dbpath, cursor):
|
def run_checks_on_open_db(dbpath: str, cursor: CursorFetchStrategy) -> None:
|
||||||
"""Run checks that will generate a sqlite3 exception if there is corruption."""
|
"""Run checks that will generate a sqlite3 exception if there is corruption."""
|
||||||
sanity_check_passed = basic_sanity_check(cursor)
|
sanity_check_passed = basic_sanity_check(cursor)
|
||||||
last_run_was_clean = last_run_was_recently_clean(cursor)
|
last_run_was_clean = last_run_was_recently_clean(cursor)
|
||||||
|
@ -278,14 +280,14 @@ def move_away_broken_database(dbfile: str) -> None:
|
||||||
os.rename(path, f"{path}{corrupt_postfix}")
|
os.rename(path, f"{path}{corrupt_postfix}")
|
||||||
|
|
||||||
|
|
||||||
def execute_on_connection(dbapi_connection, statement):
|
def execute_on_connection(dbapi_connection: Any, statement: str) -> None:
|
||||||
"""Execute a single statement with a dbapi connection."""
|
"""Execute a single statement with a dbapi connection."""
|
||||||
cursor = dbapi_connection.cursor()
|
cursor = dbapi_connection.cursor()
|
||||||
cursor.execute(statement)
|
cursor.execute(statement)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
def query_on_connection(dbapi_connection, statement):
|
def query_on_connection(dbapi_connection: Any, statement: str) -> Any:
|
||||||
"""Execute a single statement with a dbapi connection and return the result."""
|
"""Execute a single statement with a dbapi connection and return the result."""
|
||||||
cursor = dbapi_connection.cursor()
|
cursor = dbapi_connection.cursor()
|
||||||
cursor.execute(statement)
|
cursor.execute(statement)
|
||||||
|
@ -294,30 +296,34 @@ def query_on_connection(dbapi_connection, statement):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _warn_unsupported_dialect(dialect):
|
def _warn_unsupported_dialect(dialect_name: str) -> None:
|
||||||
"""Warn about unsupported database version."""
|
"""Warn about unsupported database version."""
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Database %s is not supported; Home Assistant supports %s. "
|
"Database %s is not supported; Home Assistant supports %s. "
|
||||||
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
|
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
|
||||||
"starting. Please migrate your database to a supported software before then",
|
"starting. Please migrate your database to a supported software before then",
|
||||||
dialect,
|
dialect_name,
|
||||||
"MariaDB ≥ 10.3, MySQL ≥ 8.0, PostgreSQL ≥ 12, SQLite ≥ 3.31.0",
|
"MariaDB ≥ 10.3, MySQL ≥ 8.0, PostgreSQL ≥ 12, SQLite ≥ 3.31.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _warn_unsupported_version(server_version, dialect, minimum_version):
|
def _warn_unsupported_version(
|
||||||
|
server_version: str, dialect_name: str, minimum_version: str
|
||||||
|
) -> None:
|
||||||
"""Warn about unsupported database version."""
|
"""Warn about unsupported database version."""
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Version %s of %s is not supported; minimum supported version is %s. "
|
"Version %s of %s is not supported; minimum supported version is %s. "
|
||||||
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
|
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
|
||||||
"starting. Please upgrade your database software before then",
|
"starting. Please upgrade your database software before then",
|
||||||
server_version,
|
server_version,
|
||||||
dialect,
|
dialect_name,
|
||||||
minimum_version,
|
minimum_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_version_from_server_response(server_response):
|
def _extract_version_from_server_response(
|
||||||
|
server_response: str,
|
||||||
|
) -> AwesomeVersion | None:
|
||||||
"""Attempt to extract version from server response."""
|
"""Attempt to extract version from server response."""
|
||||||
try:
|
try:
|
||||||
return AwesomeVersion(
|
return AwesomeVersion(
|
||||||
|
@ -330,8 +336,11 @@ def _extract_version_from_server_response(server_response):
|
||||||
|
|
||||||
|
|
||||||
def setup_connection_for_dialect(
|
def setup_connection_for_dialect(
|
||||||
instance, dialect_name, dbapi_connection, first_connection
|
instance: Recorder,
|
||||||
):
|
dialect_name: str,
|
||||||
|
dbapi_connection: Any,
|
||||||
|
first_connection: bool,
|
||||||
|
) -> None:
|
||||||
"""Execute statements needed for dialect connection."""
|
"""Execute statements needed for dialect connection."""
|
||||||
# Returns False if the the connection needs to be setup
|
# Returns False if the the connection needs to be setup
|
||||||
# on the next connection, returns True if the connection
|
# on the next connection, returns True if the connection
|
||||||
|
@ -406,7 +415,7 @@ def setup_connection_for_dialect(
|
||||||
_warn_unsupported_dialect(dialect_name)
|
_warn_unsupported_dialect(dialect_name)
|
||||||
|
|
||||||
|
|
||||||
def end_incomplete_runs(session, start_time):
|
def end_incomplete_runs(session: Session, start_time: datetime) -> None:
|
||||||
"""End any incomplete recorder runs."""
|
"""End any incomplete recorder runs."""
|
||||||
for run in session.query(RecorderRuns).filter_by(end=None):
|
for run in session.query(RecorderRuns).filter_by(end=None):
|
||||||
run.closed_incorrect = True
|
run.closed_incorrect = True
|
||||||
|
@ -423,9 +432,9 @@ def retryable_database_job(description: str) -> Callable:
|
||||||
The job should return True if it finished, and False if it needs to be rescheduled.
|
The job should return True if it finished, and False if it needs to be rescheduled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(job: Callable) -> Callable:
|
def decorator(job: Callable[[Any], bool]) -> Callable:
|
||||||
@functools.wraps(job)
|
@functools.wraps(job)
|
||||||
def wrapper(instance: Recorder, *args, **kwargs):
|
def wrapper(instance: Recorder, *args: Any, **kwargs: Any) -> bool:
|
||||||
try:
|
try:
|
||||||
return job(instance, *args, **kwargs)
|
return job(instance, *args, **kwargs)
|
||||||
except OperationalError as err:
|
except OperationalError as err:
|
||||||
|
@ -451,7 +460,7 @@ def retryable_database_job(description: str) -> Callable:
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def perodic_db_cleanups(instance: Recorder):
|
def perodic_db_cleanups(instance: Recorder) -> None:
|
||||||
"""Run any database cleanups that need to happen perodiclly.
|
"""Run any database cleanups that need to happen perodiclly.
|
||||||
|
|
||||||
These cleanups will happen nightly or after any purge.
|
These cleanups will happen nightly or after any purge.
|
||||||
|
@ -465,7 +474,7 @@ def perodic_db_cleanups(instance: Recorder):
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def write_lock_db_sqlite(instance: Recorder):
|
def write_lock_db_sqlite(instance: Recorder) -> Generator[None, None, None]:
|
||||||
"""Lock database for writes."""
|
"""Lock database for writes."""
|
||||||
assert instance.engine is not None
|
assert instance.engine is not None
|
||||||
with instance.engine.connect() as connection:
|
with instance.engine.connect() as connection:
|
||||||
|
@ -490,4 +499,5 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool:
|
||||||
"""
|
"""
|
||||||
if DATA_INSTANCE not in hass.data:
|
if DATA_INSTANCE not in hass.data:
|
||||||
return False
|
return False
|
||||||
return hass.data[DATA_INSTANCE].migration_in_progress
|
instance: Recorder = hass.data[DATA_INSTANCE]
|
||||||
|
return instance.migration_in_progress
|
||||||
|
|
11
mypy.ini
11
mypy.ini
|
@ -1650,6 +1650,17 @@ no_implicit_optional = true
|
||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.recorder.util]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.remote.*]
|
[mypy-homeassistant.components.remote.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue