Additional strict typing for recorder (#68860)

J. Nick Koston committed on 2022-03-30 06:20:44 -10:00 (via GitHub)
parent fa33ac73f3
commit d75f577b88
6 changed files with 109 additions and 39 deletions

diff --git a/.strict-typing b/.strict-typing

@@ -172,8 +172,12 @@ homeassistant.components.pure_energie.*
 homeassistant.components.rainmachine.*
 homeassistant.components.rdw.*
 homeassistant.components.recollect_waste.*
-homeassistant.components.recorder.models
+homeassistant.components.recorder
+homeassistant.components.recorder.const
+homeassistant.components.recorder.backup
+homeassistant.components.recorder.executor
 homeassistant.components.recorder.history
+homeassistant.components.recorder.models
 homeassistant.components.recorder.pool
 homeassistant.components.recorder.purge
 homeassistant.components.recorder.repack
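
The modules listed in .strict-typing opt in to mypy's strict checks; the matching [mypy-...] sections are generated into mypy.ini, as the last file in this diff shows. As a rough illustration of what the strict flags enforce — a hypothetical module, not code from this commit — disallow_untyped_defs alone rejects any unannotated definition:

    from datetime import timedelta

    def purge_window(days):  # strict mypy: "Function is missing a type annotation"
        return timedelta(days=days)

    def purge_window_typed(days: int) -> timedelta:  # accepted under strict mode
        return timedelta(days=days)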

diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py

@@ -11,7 +11,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, TypeVar
+from typing import Any, TypeVar, cast

 from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
@@ -214,7 +214,8 @@ MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1

 def get_instance(hass: HomeAssistant) -> Recorder:
     """Get the recorder instance."""
-    return hass.data[DATA_INSTANCE]
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    return instance


 @bind_hass
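
hass.data is an untyped dict, so hass.data[DATA_INSTANCE] is Any to mypy; binding it to an annotated local is what gives get_instance a checked Recorder return type. A minimal standalone sketch of the pattern (the names here are stand-ins, not the real Home Assistant API):

    from typing import Any

    class Recorder:  # stand-in for the real recorder class
        pass

    data: dict[str, Any] = {}  # plays the role of hass.data
    DATA_INSTANCE = "recorder_instance"
    data[DATA_INSTANCE] = Recorder()

    def get_instance() -> Recorder:
        # The annotated local converts the untyped lookup into a checked value.
        instance: Recorder = data[DATA_INSTANCE]
        return instance
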
@@ -225,10 +226,13 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
     """
     if DATA_INSTANCE not in hass.data:
         return False
-    return hass.data[DATA_INSTANCE].entity_filter(entity_id)
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    return instance.entity_filter(entity_id)


-def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns | None:
+def run_information(
+    hass: HomeAssistant, point_in_time: datetime | None = None
+) -> RecorderRuns | None:
     """Return information about current run.

     There is also the run that covers point_in_time.
@@ -241,21 +245,20 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
 def run_information_from_instance(
-    hass, point_in_time: datetime | None = None
+    hass: HomeAssistant, point_in_time: datetime | None = None
 ) -> RecorderRuns | None:
     """Return information about current run from the existing instance.

     Does not query the database for older runs.
     """
-    ins = hass.data[DATA_INSTANCE]
+    ins = get_instance(hass)
     if point_in_time is None or point_in_time > ins.recording_start:
         return ins.run_info
     return None


 def run_information_with_session(
-    session, point_in_time: datetime | None = None
+    session: Session, point_in_time: datetime | None = None
 ) -> RecorderRuns | None:
     """Return information about current run from the database."""
     recorder_runs = RecorderRuns
@@ -266,9 +269,9 @@ def run_information_with_session(
         (recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
     )
-    res = query.first()
-    if res:
+    if (res := query.first()) is not None:
         session.expunge(res)
-    return res
+        return cast(RecorderRuns, res)
+    return res
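
SQLAlchemy's query.first() is untyped, so the walrus expression narrows the result to non-None and cast() then records the concrete row type for mypy; cast() has no runtime effect. A reduced sketch, with first() standing in for the real query:

    from __future__ import annotations

    from typing import Any, cast

    class RecorderRuns:
        pass

    def first() -> Any:  # stand-in for query.first()
        return RecorderRuns()

    def current_run() -> RecorderRuns | None:
        # Bind and narrow in one expression, then tell mypy the row type.
        if (res := first()) is not None:
            return cast(RecorderRuns, res)
        return res
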
@@ -318,9 +321,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     return await instance.async_db_ready


-async def _process_recorder_platform(hass, domain, platform):
+async def _process_recorder_platform(
+    hass: HomeAssistant, domain: str, platform: Any
+) -> None:
     """Process a recorder platform."""
-    hass.data[DOMAIN][domain] = platform
+    platforms: dict[str, Any] = hass.data[DOMAIN]
+    platforms[domain] = platform
     if hasattr(platform, "exclude_attributes"):
         hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass)
@@ -586,11 +592,11 @@ class Recorder(threading.Thread):
         self.db_url = uri
         self.db_max_retries = db_max_retries
         self.db_retry_wait = db_retry_wait
-        self.async_db_ready: asyncio.Future = asyncio.Future()
+        self.async_db_ready: asyncio.Future[bool] = asyncio.Future()
         self.async_recorder_ready = asyncio.Event()
         self._queue_watch = threading.Event()
         self.engine: Engine | None = None
-        self.run_info: Any = None
+        self.run_info: RecorderRuns | None = None

         self.entity_filter = entity_filter
         self.exclude_t = exclude_t
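
Parameterizing the future as asyncio.Future[bool] types both ends of the handoff: set_result() must now receive a bool, and awaiting the future yields bool instead of Any. A small runnable sketch:

    import asyncio

    async def main() -> None:
        db_ready: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
        db_ready.set_result(True)     # set_result("yes") would now fail mypy
        ready: bool = await db_ready  # inferred as bool, not Any
        print(ready)

    asyncio.run(main())
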
@@ -616,12 +622,12 @@ class Recorder(threading.Thread):
         self.enabled = True

-    def set_enable(self, enable):
+    def set_enable(self, enable: bool) -> None:
         """Enable or disable recording events and states."""
         self.enabled = enable

     @callback
-    def async_start_executor(self):
+    def async_start_executor(self) -> None:
         """Start the executor."""
         self._db_executor = DBInterruptibleThreadPoolExecutor(
             thread_name_prefix=DB_WORKER_PREFIX,
@@ -629,13 +635,13 @@ class Recorder(threading.Thread):
             shutdown_hook=self._shutdown_pool,
         )

-    def _shutdown_pool(self):
+    def _shutdown_pool(self) -> None:
         """Close the dbpool connections in the current thread."""
-        if hasattr(self.engine.pool, "shutdown"):
+        if self.engine and hasattr(self.engine.pool, "shutdown"):
             self.engine.pool.shutdown()

     @callback
-    def async_initialize(self):
+    def async_initialize(self) -> None:
         """Initialize the recorder."""
         self._event_listener = self.hass.bus.async_listen(
             MATCH_ALL, self.event_listener, event_filter=self._async_event_filter
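
With engine declared as Engine | None, mypy rejects self.engine.pool until the value is narrowed; the added self.engine truthiness guard provides that narrowing. A reduced sketch (Engine and Pool are stand-ins for the SQLAlchemy types):

    from __future__ import annotations

    class Pool:
        def shutdown(self) -> None:
            print("pool connections closed")

    class Engine:
        def __init__(self) -> None:
            self.pool = Pool()

    class Recorder:
        def __init__(self) -> None:
            self.engine: Engine | None = None

        def _shutdown_pool(self) -> None:
            # Without the guard mypy reports:
            #   Item "None" of "Engine | None" has no attribute "pool"
            if self.engine and hasattr(self.engine.pool, "shutdown"):
                self.engine.pool.shutdown()
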
@@ -658,7 +664,7 @@ class Recorder(threading.Thread):
         self._db_executor = None

     @callback
-    def _async_check_queue(self, *_):
+    def _async_check_queue(self, *_: Any) -> None:
         """Periodic check of the queue size to ensure we do not exaust memory.

         The queue grows during migraton or if something really goes wrong.
@@ -704,21 +710,23 @@ class Recorder(threading.Thread):
         # Unknown what it is.
         return True

-    def do_adhoc_purge(self, **kwargs):
+    def do_adhoc_purge(self, **kwargs: Any) -> None:
         """Trigger an adhoc purge retaining keep_days worth of data."""
         keep_days = kwargs.get(ATTR_KEEP_DAYS, self.keep_days)
-        repack = kwargs.get(ATTR_REPACK)
-        apply_filter = kwargs.get(ATTR_APPLY_FILTER)
+        repack = cast(bool, kwargs[ATTR_REPACK])
+        apply_filter = cast(bool, kwargs[ATTR_APPLY_FILTER])

         purge_before = dt_util.utcnow() - timedelta(days=keep_days)
         self.queue.put(PurgeTask(purge_before, repack, apply_filter))

-    def do_adhoc_purge_entities(self, entity_ids, domains, entity_globs):
+    def do_adhoc_purge_entities(
+        self, entity_ids: set[str], domains: list[str], entity_globs: list[str]
+    ) -> None:
         """Trigger an adhoc purge of requested entities."""
-        entity_filter = generate_filter(domains, entity_ids, [], [], entity_globs)
+        entity_filter = generate_filter(domains, list(entity_ids), [], [], entity_globs)
         self.queue.put(PurgeEntitiesTask(entity_filter))

-    def do_adhoc_statistics(self, **kwargs):
+    def do_adhoc_statistics(self, **kwargs: Any) -> None:
         """Trigger an adhoc statistics run."""
         if not (start := kwargs.get("start")):
             start = statistics.get_start_time()
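
The move from kwargs.get(...) to cast(bool, kwargs[...]) leans on upstream validation: the service schema guarantees both keys exist and hold booleans, and the cast records that guarantee for mypy (get() would have produced an Optional value). A reduced sketch with the validation assumed rather than shown:

    from typing import Any, cast

    ATTR_REPACK = "repack"
    ATTR_APPLY_FILTER = "apply_filter"

    def do_adhoc_purge(**kwargs: Any) -> None:
        # Indexing asserts the key exists; cast() asserts the value type.
        # Both rest on schema validation performed before this call.
        repack = cast(bool, kwargs[ATTR_REPACK])
        apply_filter = cast(bool, kwargs[ATTR_APPLY_FILTER])
        print(repack, apply_filter)

    do_adhoc_purge(repack=True, apply_filter=False)
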
@@ -812,22 +820,26 @@ class Recorder(threading.Thread):
         self.queue.put(StatisticsTask(start))

     @callback
-    def async_adjust_statistics(self, statistic_id, start_time, sum_adjustment):
+    def async_adjust_statistics(
+        self, statistic_id: str, start_time: datetime, sum_adjustment: float
+    ) -> None:
         """Adjust statistics."""
         self.queue.put(AdjustStatisticsTask(statistic_id, start_time, sum_adjustment))

     @callback
-    def async_clear_statistics(self, statistic_ids):
+    def async_clear_statistics(self, statistic_ids: list[str]) -> None:
         """Clear statistics for a list of statistic_ids."""
         self.queue.put(ClearStatisticsTask(statistic_ids))

     @callback
-    def async_update_statistics_metadata(self, statistic_id, unit_of_measurement):
+    def async_update_statistics_metadata(
+        self, statistic_id: str, unit_of_measurement: str | None
+    ) -> None:
         """Update statistics metadata for a statistic_id."""
         self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))

     @callback
-    def async_external_statistics(self, metadata, stats):
+    def async_external_statistics(self, metadata: dict, stats: Iterable[dict]) -> None:
         """Schedule external statistics."""
         self.queue.put(ExternalStatisticsTask(metadata, stats))
@@ -995,7 +1007,7 @@ class Recorder(threading.Thread):
     def _lock_database(self, task: DatabaseLockTask) -> None:
         @callback
-        def _async_set_database_locked(task: DatabaseLockTask):
+        def _async_set_database_locked(task: DatabaseLockTask) -> None:
             task.database_locked.set()

         with write_lock_db_sqlite(self):
@@ -1285,8 +1297,11 @@ class Recorder(threading.Thread):
         kwargs: dict[str, Any] = {}
         self._completed_first_database_setup = False

-        def setup_recorder_connection(dbapi_connection, connection_record):
+        def setup_recorder_connection(
+            dbapi_connection: Any, connection_record: Any
+        ) -> None:
             """Dbapi specific connection settings."""
+            assert self.engine is not None
             setup_connection_for_dialect(
                 self,
                 self.engine.dialect.name,
@@ -1366,6 +1381,7 @@ class Recorder(threading.Thread):
         """End the recorder session."""
         if self.event_session is None:
             return
+        assert self.run_info is not None
         try:
             self.run_info.end = dt_util.utcnow()
             self.event_session.add(self.run_info)
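
Both new asserts (self.engine in the connection hook above, self.run_info here) use the standard mypy narrowing idiom: after assert x is not None, the Optional attribute is treated as non-None for the rest of the scope, and the assert documents an invariant the recorder already maintains. A reduced sketch:

    from __future__ import annotations

    from datetime import datetime

    class RecorderRuns:
        def __init__(self) -> None:
            self.end: datetime | None = None

    class Recorder:
        def __init__(self) -> None:
            self.run_info: RecorderRuns | None = RecorderRuns()

        def _end_session(self) -> None:
            # The recorder only ends sessions it started, but mypy cannot
            # know that; the assert narrows RecorderRuns | None.
            assert self.run_info is not None
            self.run_info.end = datetime.utcnow()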

diff --git a/homeassistant/components/recorder/executor.py b/homeassistant/components/recorder/executor.py

@@ -10,7 +10,9 @@ import weakref

 from homeassistant.util.executor import InterruptibleThreadPoolExecutor


-def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
+def _worker_with_shutdown_hook(
+    shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
+) -> None:
     """Create a worker that calls a function after its finished."""
     _worker(*args, **kwargs)
     shutdown_hook()
@@ -37,7 +39,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
         # When the executor gets lost, the weakref callback will wake up
         # the worker threads.
-        def weakref_cb(_, q=self._work_queue):  # pylint: disable=invalid-name
+        def weakref_cb(_: Any, q=self._work_queue) -> None:  # type: ignore[no-untyped-def] # pylint: disable=invalid-name
             q.put(None)

         num_threads = len(self._threads)
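
shutdown_hook is a zero-argument callback, hence Callable[[], None]; weakref_cb, by contrast, keeps its untyped q=self._work_queue default and gets a targeted type: ignore instead of a full signature. A runnable sketch of the hook pattern (_worker is a stand-in here; in the real module it is the stdlib thread-pool worker):

    from collections.abc import Callable
    from typing import Any

    def _worker(*args: Any, **kwargs: Any) -> None:  # stand-in worker
        print("worker ran:", args, kwargs)

    def _worker_with_shutdown_hook(
        shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
    ) -> None:
        # Run the worker to completion, then fire the hook exactly once.
        _worker(*args, **kwargs)
        shutdown_hook()

    _worker_with_shutdown_hook(lambda: print("shutdown hook ran"), 1, flag=True)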

diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py

@@ -2,6 +2,7 @@
 import contextlib
 from datetime import timedelta
 import logging
+from typing import Any

 import sqlalchemy
 from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@@ -43,8 +44,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
     raise ex


-def get_schema_version(instance):
+def get_schema_version(instance: Any) -> int:
     """Get the schema version."""
+    assert instance.get_session is not None
     with session_scope(session=instance.get_session()) as session:
         res = (
             session.query(SchemaChanges)
@@ -62,13 +64,14 @@ def get_schema_version(instance):
     return current_version


-def schema_is_current(current_version):
+def schema_is_current(current_version: int) -> bool:
     """Check if the schema is current."""
     return current_version == SCHEMA_VERSION


-def migrate_schema(instance, current_version):
+def migrate_schema(instance: Any, current_version: int) -> None:
     """Check if the schema needs to be upgraded."""
+    assert instance.get_session is not None
     _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
     for version in range(current_version, SCHEMA_VERSION):
         new_version = version + 1
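
instance stays typed as Any in this module (a precise Recorder annotation would require importing from the integration's __init__, which imports migration in turn), so the added asserts pin down just the attributes used. For comparison only — a hypothetical refactor, not part of this commit — a TYPE_CHECKING-guarded import is the usual way to get the precise type without a runtime cycle:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy only; nothing is imported at runtime, so this cannot
        # create an import cycle between migration.py and __init__.py.
        from homeassistant.components.recorder import Recorder

    def get_schema_version(instance: Recorder) -> int:
        """Hypothetical precisely typed variant of the function above."""
        assert instance.get_session is not None
        return 0  # stands in for the real schema query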

diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py

@@ -291,6 +291,7 @@ def _purge_old_recorder_runs(
 ) -> None:
     """Purge all old recorder runs."""
     # Recorder runs is small, no need to batch run it
+    assert instance.run_info is not None
     deleted_rows = (
         session.query(RecorderRuns)
         .filter(RecorderRuns.start < purge_before)

diff --git a/mypy.ini b/mypy.ini

@@ -1694,7 +1694,40 @@ no_implicit_optional = true
 warn_return_any = true
 warn_unreachable = true

-[mypy-homeassistant.components.recorder.models]
+[mypy-homeassistant.components.recorder]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.const]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.backup]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.executor]
 check_untyped_defs = true
 disallow_incomplete_defs = true
 disallow_subclassing_any = true
@@ -1716,6 +1749,17 @@ no_implicit_optional = true
 warn_return_any = true
 warn_unreachable = true

+[mypy-homeassistant.components.recorder.models]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
 [mypy-homeassistant.components.recorder.pool]
 check_untyped_defs = true
 disallow_incomplete_defs = true