From 606d5441573f52de961010551b34d23aef3317dc Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 22 Jul 2022 11:58:26 +0200 Subject: [PATCH] Use recorder get_instance function to improve typing (#75567) --- homeassistant/components/recorder/__init__.py | 11 ++--- homeassistant/components/recorder/backup.py | 8 ++-- .../components/recorder/statistics.py | 7 +-- homeassistant/components/recorder/util.py | 10 +++- .../components/recorder/websocket_api.py | 24 ++++------ tests/components/recorder/common.py | 5 +- tests/components/recorder/test_backup.py | 8 ++-- tests/components/recorder/test_init.py | 46 ++++++++++--------- tests/components/recorder/test_migrate.py | 11 ++--- tests/components/recorder/test_statistics.py | 6 +-- tests/components/recorder/test_util.py | 16 +++---- .../components/recorder/test_websocket_api.py | 3 +- tests/components/sensor/test_recorder.py | 7 ++- 13 files changed, 77 insertions(+), 85 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 4063e443e8b..238b013f366 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -31,6 +31,7 @@ from .const import ( from .core import Recorder from .services import async_register_services from .tasks import AddRecorderPlatformTask +from .util import get_instance _LOGGER = logging.getLogger(__name__) @@ -108,12 +109,6 @@ CONFIG_SCHEMA = vol.Schema( ) -def get_instance(hass: HomeAssistant) -> Recorder: - """Get the recorder instance.""" - instance: Recorder = hass.data[DATA_INSTANCE] - return instance - - @bind_hass def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool: """Check if an entity is being recorded. @@ -122,7 +117,7 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool: """ if DATA_INSTANCE not in hass.data: return False - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) return instance.entity_filter(entity_id) @@ -177,5 +172,5 @@ async def _process_recorder_platform( hass: HomeAssistant, domain: str, platform: Any ) -> None: """Process a recorder platform.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) instance.queue_task(AddRecorderPlatformTask(domain, platform)) diff --git a/homeassistant/components/recorder/backup.py b/homeassistant/components/recorder/backup.py index cec9f85748b..a1f6f4f39bc 100644 --- a/homeassistant/components/recorder/backup.py +++ b/homeassistant/components/recorder/backup.py @@ -4,9 +4,7 @@ from logging import getLogger from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from . import Recorder -from .const import DATA_INSTANCE -from .util import async_migration_in_progress +from .util import async_migration_in_progress, get_instance _LOGGER = getLogger(__name__) @@ -14,7 +12,7 @@ _LOGGER = getLogger(__name__) async def async_pre_backup(hass: HomeAssistant) -> None: """Perform operations before a backup starts.""" _LOGGER.info("Backup start notification, locking database for writes") - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) if async_migration_in_progress(hass): raise HomeAssistantError("Database migration in progress") await instance.lock_database() @@ -22,7 +20,7 @@ async def async_pre_backup(hass: HomeAssistant) -> None: async def async_post_backup(hass: HomeAssistant) -> None: """Perform operations after a backup finishes.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) _LOGGER.info("Backup end notification, releasing write lock") if not instance.unlock_database(): raise HomeAssistantError("Could not release database write lock") diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 4ebd5e17902..bce77e8a31e 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -41,7 +41,7 @@ import homeassistant.util.temperature as temperature_util from homeassistant.util.unit_system import UnitSystem import homeassistant.util.volume as volume_util -from .const import DATA_INSTANCE, DOMAIN, MAX_ROWS_TO_PURGE, SupportedDialect +from .const import DOMAIN, MAX_ROWS_TO_PURGE, SupportedDialect from .db_schema import Statistics, StatisticsMeta, StatisticsRuns, StatisticsShortTerm from .models import ( StatisticData, @@ -53,6 +53,7 @@ from .models import ( from .util import ( execute, execute_stmt_lambda_element, + get_instance, retryable_database_job, session_scope, ) @@ -209,7 +210,7 @@ def async_setup(hass: HomeAssistant) -> None: @callback def _async_entity_id_changed(event: Event) -> None: - hass.data[DATA_INSTANCE].async_update_statistics_metadata( + get_instance(hass).async_update_statistics_metadata( event.data["old_entity_id"], new_statistic_id=event.data["entity_id"] ) @@ -1385,7 +1386,7 @@ def _async_import_statistics( statistic["last_reset"] = dt_util.as_utc(last_reset) # Insert job in recorder's queue - hass.data[DATA_INSTANCE].async_import_statistics(metadata, statistics) + get_instance(hass).async_import_statistics(metadata, statistics) @callback diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index c1fbc831987..fdf42665ef5 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -80,7 +80,7 @@ def session_scope( ) -> Generator[Session, None, None]: """Provide a transactional scope around a series of operations.""" if session is None and hass is not None: - session = hass.data[DATA_INSTANCE].get_session() + session = get_instance(hass).get_session() if session is None: raise RuntimeError("Session required") @@ -559,7 +559,7 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool: """ if DATA_INSTANCE not in hass.data: return False - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) return instance.migration_in_progress @@ -577,3 +577,9 @@ def second_sunday(year: int, month: int) -> date: def is_second_sunday(date_time: datetime) -> bool: """Check if a time is the second sunday of the month.""" return bool(second_sunday(date_time.year, date_time.month).day == date_time.day) + + +def get_instance(hass: HomeAssistant) -> Recorder: + """Get the recorder instance.""" + instance: Recorder = hass.data[DATA_INSTANCE] + return instance diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 4ba5f3c8a8b..c143d8b4f0b 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING import voluptuous as vol @@ -11,17 +10,14 @@ from homeassistant.core import HomeAssistant, callback, valid_entity_id from homeassistant.helpers import config_validation as cv from homeassistant.util import dt as dt_util -from .const import DATA_INSTANCE, MAX_QUEUE_BACKLOG +from .const import MAX_QUEUE_BACKLOG from .statistics import ( async_add_external_statistics, async_import_statistics, list_statistic_ids, validate_statistics, ) -from .util import async_migration_in_progress - -if TYPE_CHECKING: - from . import Recorder +from .util import async_migration_in_progress, get_instance _LOGGER: logging.Logger = logging.getLogger(__package__) @@ -50,7 +46,7 @@ async def ws_validate_statistics( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Fetch a list of available statistic_id.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) statistic_ids = await instance.async_add_executor_job( validate_statistics, hass, @@ -74,7 +70,7 @@ def ws_clear_statistics( Note: The WS call posts a job to the recorder's queue and then returns, it doesn't wait until the job is completed. """ - hass.data[DATA_INSTANCE].async_clear_statistics(msg["statistic_ids"]) + get_instance(hass).async_clear_statistics(msg["statistic_ids"]) connection.send_result(msg["id"]) @@ -89,7 +85,7 @@ async def ws_get_statistics_metadata( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Get metadata for a list of statistic_ids.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) statistic_ids = await instance.async_add_executor_job( list_statistic_ids, hass, msg.get("statistic_ids") ) @@ -109,7 +105,7 @@ def ws_update_statistics_metadata( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Update statistics metadata for a statistic_id.""" - hass.data[DATA_INSTANCE].async_update_statistics_metadata( + get_instance(hass).async_update_statistics_metadata( msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"] ) connection.send_result(msg["id"]) @@ -137,7 +133,7 @@ def ws_adjust_sum_statistics( connection.send_error(msg["id"], "invalid_start_time", "Invalid start time") return - hass.data[DATA_INSTANCE].async_adjust_statistics( + get_instance(hass).async_adjust_statistics( msg["statistic_id"], start_time, msg["adjustment"] ) connection.send_result(msg["id"]) @@ -193,7 +189,7 @@ def ws_info( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Return status of the recorder.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) backlog = instance.backlog if instance else None migration_in_progress = async_migration_in_progress(hass) @@ -219,7 +215,7 @@ async def ws_backup_start( """Backup start notification.""" _LOGGER.info("Backup start notification, locking database for writes") - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) try: await instance.lock_database() except TimeoutError as err: @@ -236,7 +232,7 @@ async def ws_backup_end( ) -> None: """Backup end notification.""" - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) _LOGGER.info("Backup end notification, releasing write lock") if not instance.unlock_database(): connection.send_error( diff --git a/tests/components/recorder/common.py b/tests/components/recorder/common.py index 20df89eca5b..083630c7ea8 100644 --- a/tests/components/recorder/common.py +++ b/tests/components/recorder/common.py @@ -62,7 +62,7 @@ def wait_recording_done(hass: HomeAssistant) -> None: hass.block_till_done() trigger_db_commit(hass) hass.block_till_done() - hass.data[recorder.DATA_INSTANCE].block_till_done() + recorder.get_instance(hass).block_till_done() hass.block_till_done() @@ -105,8 +105,7 @@ def async_trigger_db_commit(hass: HomeAssistant) -> None: async def async_recorder_block_till_done(hass: HomeAssistant) -> None: """Non blocking version of recorder.block_till_done().""" - instance: recorder.Recorder = hass.data[recorder.DATA_INSTANCE] - await hass.async_add_executor_job(instance.block_till_done) + await hass.async_add_executor_job(recorder.get_instance(hass).block_till_done) def corrupt_db_file(test_db_file): diff --git a/tests/components/recorder/test_backup.py b/tests/components/recorder/test_backup.py index d7c5a55b56a..e829c2aa13b 100644 --- a/tests/components/recorder/test_backup.py +++ b/tests/components/recorder/test_backup.py @@ -13,7 +13,7 @@ from homeassistant.exceptions import HomeAssistantError async def test_async_pre_backup(hass: HomeAssistant, recorder_mock) -> None: """Test pre backup.""" with patch( - "homeassistant.components.recorder.backup.Recorder.lock_database" + "homeassistant.components.recorder.core.Recorder.lock_database" ) as lock_mock: await async_pre_backup(hass) assert lock_mock.called @@ -24,7 +24,7 @@ async def test_async_pre_backup_with_timeout( ) -> None: """Test pre backup with timeout.""" with patch( - "homeassistant.components.recorder.backup.Recorder.lock_database", + "homeassistant.components.recorder.core.Recorder.lock_database", side_effect=TimeoutError(), ) as lock_mock, pytest.raises(TimeoutError): await async_pre_backup(hass) @@ -45,7 +45,7 @@ async def test_async_pre_backup_with_migration( async def test_async_post_backup(hass: HomeAssistant, recorder_mock) -> None: """Test post backup.""" with patch( - "homeassistant.components.recorder.backup.Recorder.unlock_database" + "homeassistant.components.recorder.core.Recorder.unlock_database" ) as unlock_mock: await async_post_backup(hass) assert unlock_mock.called @@ -54,7 +54,7 @@ async def test_async_post_backup(hass: HomeAssistant, recorder_mock) -> None: async def test_async_post_backup_failure(hass: HomeAssistant, recorder_mock) -> None: """Test post backup failure.""" with patch( - "homeassistant.components.recorder.backup.Recorder.unlock_database", + "homeassistant.components.recorder.core.Recorder.unlock_database", return_value=False, ) as unlock_mock, pytest.raises(HomeAssistantError): await async_post_backup(hass) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 3e25a54e39d..82444f86a05 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -24,7 +24,7 @@ from homeassistant.components.recorder import ( Recorder, get_instance, ) -from homeassistant.components.recorder.const import DATA_INSTANCE, KEEPALIVE_TIME +from homeassistant.components.recorder.const import KEEPALIVE_TIME from homeassistant.components.recorder.db_schema import ( SCHEMA_VERSION, EventData, @@ -100,13 +100,13 @@ async def test_shutdown_before_startup_finishes( } hass.state = CoreState.not_running - await async_setup_recorder_instance(hass, config) - await hass.data[DATA_INSTANCE].async_db_ready + instance = await async_setup_recorder_instance(hass, config) + await instance.async_db_ready await hass.async_block_till_done() - session = await hass.async_add_executor_job(hass.data[DATA_INSTANCE].get_session) + session = await hass.async_add_executor_job(instance.get_session) - with patch.object(hass.data[DATA_INSTANCE], "engine"): + with patch.object(instance, "engine"): hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() await hass.async_stop() @@ -214,14 +214,16 @@ async def test_saving_many_states( hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT ): """Test we expire after many commits.""" - await async_setup_recorder_instance(hass, {recorder.CONF_COMMIT_INTERVAL: 0}) + instance = await async_setup_recorder_instance( + hass, {recorder.CONF_COMMIT_INTERVAL: 0} + ) entity_id = "test.recorder" attributes = {"test_attr": 5, "test_attr_10": "nice"} - with patch.object( - hass.data[DATA_INSTANCE].event_session, "expire_all" - ) as expire_all, patch.object(recorder.core, "EXPIRE_AFTER_COMMITS", 2): + with patch.object(instance.event_session, "expire_all") as expire_all, patch.object( + recorder.core, "EXPIRE_AFTER_COMMITS", 2 + ): for _ in range(3): hass.states.async_set(entity_id, "on", attributes) await async_wait_recording_done(hass) @@ -269,14 +271,14 @@ def test_saving_state_with_exception(hass, hass_recorder, caplog): attributes = {"test_attr": 5, "test_attr_10": "nice"} def _throw_if_state_in_session(*args, **kwargs): - for obj in hass.data[DATA_INSTANCE].event_session: + for obj in get_instance(hass).event_session: if isinstance(obj, States): raise OperationalError( "insert the state", "fake params", "forced to fail" ) with patch("time.sleep"), patch.object( - hass.data[DATA_INSTANCE].event_session, + get_instance(hass).event_session, "flush", side_effect=_throw_if_state_in_session, ): @@ -307,14 +309,14 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog): attributes = {"test_attr": 5, "test_attr_10": "nice"} def _throw_if_state_in_session(*args, **kwargs): - for obj in hass.data[DATA_INSTANCE].event_session: + for obj in get_instance(hass).event_session: if isinstance(obj, States): raise SQLAlchemyError( "insert the state", "fake params", "forced to fail" ) with patch("time.sleep"), patch.object( - hass.data[DATA_INSTANCE].event_session, + get_instance(hass).event_session, "flush", side_effect=_throw_if_state_in_session, ): @@ -390,7 +392,7 @@ def test_saving_event(hass, hass_recorder): assert len(events) == 1 event: Event = events[0] - hass.data[DATA_INSTANCE].block_till_done() + get_instance(hass).block_till_done() events: list[Event] = [] with session_scope(hass=hass) as session: @@ -421,7 +423,7 @@ def test_saving_event(hass, hass_recorder): def test_saving_state_with_commit_interval_zero(hass_recorder): """Test saving a state with a commit interval of zero.""" hass = hass_recorder({"commit_interval": 0}) - assert hass.data[DATA_INSTANCE].commit_interval == 0 + get_instance(hass).commit_interval == 0 entity_id = "test.recorder" state = "restoring_from_db" @@ -690,7 +692,7 @@ def run_tasks_at_time(hass, test_time): """Advance the clock and wait for any callbacks to finish.""" fire_time_changed(hass, test_time) hass.block_till_done() - hass.data[DATA_INSTANCE].block_till_done() + get_instance(hass).block_till_done() @pytest.mark.parametrize("enable_nightly_purge", [True]) @@ -1258,7 +1260,7 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog): sqlite3_exception.__cause__ = sqlite3.DatabaseError() with patch.object( - hass.data[DATA_INSTANCE].event_session, + get_instance(hass).event_session, "close", side_effect=OperationalError("statement", {}, []), ): @@ -1267,7 +1269,7 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog): await async_wait_recording_done(hass) with patch.object( - hass.data[DATA_INSTANCE].event_session, + get_instance(hass).event_session, "commit", side_effect=[sqlite3_exception, None], ): @@ -1357,7 +1359,7 @@ async def test_database_lock_and_unlock( with session_scope(hass=hass) as session: return list(session.query(Events).filter_by(event_type=event_type)) - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) assert await instance.lock_database() @@ -1399,7 +1401,7 @@ async def test_database_lock_and_overflow( with session_scope(hass=hass) as session: return list(session.query(Events).filter_by(event_type=event_type)) - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) with patch.object(recorder.core, "MAX_QUEUE_BACKLOG", 1), patch.object( recorder.core, "DB_LOCK_QUEUE_CHECK_TIMEOUT", 0.1 @@ -1424,7 +1426,7 @@ async def test_database_lock_timeout(hass, recorder_mock): """Test locking database timeout when recorder stopped.""" hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) class BlockQueue(recorder.tasks.RecorderTask): event: threading.Event = threading.Event() @@ -1447,7 +1449,7 @@ async def test_database_lock_without_instance(hass, recorder_mock): """Test database lock doesn't fail if instance is not initialized.""" hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) - instance: Recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) with patch.object(instance, "engine", None): try: assert await instance.lock_database() diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 38d6a191809..a57bd246f8b 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -21,7 +21,6 @@ from sqlalchemy.pool import StaticPool from homeassistant.bootstrap import async_setup_component from homeassistant.components import persistent_notification as pn, recorder from homeassistant.components.recorder import db_schema, migration -from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.db_schema import ( SCHEMA_VERSION, RecorderRuns, @@ -82,7 +81,7 @@ async def test_migration_in_progress(hass): await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) - await hass.data[DATA_INSTANCE].async_migration_event.wait() + await recorder.get_instance(hass).async_migration_event.wait() assert recorder.util.async_migration_in_progress(hass) is True await async_wait_recording_done(hass) @@ -112,7 +111,7 @@ async def test_database_migration_failed(hass): hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "off", {}) await hass.async_block_till_done() - await hass.async_add_executor_job(hass.data[DATA_INSTANCE].join) + await hass.async_add_executor_job(recorder.get_instance(hass).join) await hass.async_block_till_done() assert recorder.util.async_migration_in_progress(hass) is False @@ -172,7 +171,7 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass): hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "off", {}) await hass.async_block_till_done() - await hass.async_add_executor_job(hass.data[DATA_INSTANCE].join) + await hass.async_add_executor_job(recorder.get_instance(hass).join) await hass.async_block_till_done() assert recorder.util.async_migration_in_progress(hass) is False @@ -201,7 +200,7 @@ async def test_events_during_migration_are_queued(hass): async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=2)) await hass.async_block_till_done() async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=4)) - await hass.data[DATA_INSTANCE].async_recorder_ready.wait() + await recorder.get_instance(hass).async_recorder_ready.wait() await async_wait_recording_done(hass) assert recorder.util.async_migration_in_progress(hass) is False @@ -232,7 +231,7 @@ async def test_events_during_migration_queue_exhausted(hass): async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=4)) await hass.async_block_till_done() hass.states.async_set("my.entity", "off", {}) - await hass.data[DATA_INSTANCE].async_recorder_ready.wait() + await recorder.get_instance(hass).async_recorder_ready.wait() await async_wait_recording_done(hass) assert recorder.util.async_migration_in_progress(hass) is False diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 30a2926844b..ee76b40a15b 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import Session from homeassistant.components import recorder from homeassistant.components.recorder import history, statistics -from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX +from homeassistant.components.recorder.const import SQLITE_URL_PREFIX from homeassistant.components.recorder.db_schema import StatisticsShortTerm from homeassistant.components.recorder.models import process_timestamp_to_utc_isoformat from homeassistant.components.recorder.statistics import ( @@ -45,7 +45,7 @@ ORIG_TZ = dt_util.DEFAULT_TIME_ZONE def test_compile_hourly_statistics(hass_recorder): """Test compiling hourly statistics.""" hass = hass_recorder() - recorder = hass.data[DATA_INSTANCE] + instance = recorder.get_instance(hass) setup_component(hass, "sensor", {}) zero, four, states = record_states(hass) hist = history.get_significant_states(hass, zero, four) @@ -142,7 +142,7 @@ def test_compile_hourly_statistics(hass_recorder): stats = get_last_short_term_statistics(hass, 1, "sensor.test3", True) assert stats == {} - recorder.get_session().query(StatisticsShortTerm).delete() + instance.get_session().query(StatisticsShortTerm).delete() # Should not fail there is nothing in the table stats = get_latest_short_term_statistics(hass, ["sensor.test1"]) assert stats == {} diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 8624719f951..ac4eeada3d3 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -13,7 +13,7 @@ from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.components import recorder from homeassistant.components.recorder import history, util -from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX +from homeassistant.components.recorder.const import SQLITE_URL_PREFIX from homeassistant.components.recorder.db_schema import RecorderRuns from homeassistant.components.recorder.models import UnsupportedDialect from homeassistant.components.recorder.util import ( @@ -35,7 +35,7 @@ def test_session_scope_not_setup(hass_recorder): """Try to create a session scope when not setup.""" hass = hass_recorder() with patch.object( - hass.data[DATA_INSTANCE], "get_session", return_value=None + util.get_instance(hass), "get_session", return_value=None ), pytest.raises(RuntimeError): with util.session_scope(hass=hass): pass @@ -547,7 +547,7 @@ def test_basic_sanity_check(hass_recorder): """Test the basic sanity checks with a missing table.""" hass = hass_recorder() - cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor() + cursor = util.get_instance(hass).engine.raw_connection().cursor() assert util.basic_sanity_check(cursor) is True @@ -560,7 +560,7 @@ def test_basic_sanity_check(hass_recorder): def test_combined_checks(hass_recorder, caplog): """Run Checks on the open database.""" hass = hass_recorder() - instance = recorder.get_instance(hass) + instance = util.get_instance(hass) instance.db_retry_wait = 0 cursor = instance.engine.raw_connection().cursor() @@ -639,8 +639,8 @@ def test_end_incomplete_runs(hass_recorder, caplog): def test_periodic_db_cleanups(hass_recorder): """Test periodic db cleanups.""" hass = hass_recorder() - with patch.object(hass.data[DATA_INSTANCE].engine, "connect") as connect_mock: - util.periodic_db_cleanups(hass.data[DATA_INSTANCE]) + with patch.object(util.get_instance(hass).engine, "connect") as connect_mock: + util.periodic_db_cleanups(util.get_instance(hass)) text_obj = connect_mock.return_value.__enter__.return_value.execute.mock_calls[0][ 1 @@ -663,11 +663,9 @@ async def test_write_lock_db( config = { recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db?timeout=0.1") } - await async_setup_recorder_instance(hass, config) + instance = await async_setup_recorder_instance(hass, config) await hass.async_block_till_done() - instance = hass.data[DATA_INSTANCE] - def _drop_table(): with instance.engine.connect() as connection: connection.execute(text("DROP TABLE events;")) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index a7ac01cf1d1..283883030fa 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -8,7 +8,6 @@ import pytest from pytest import approx from homeassistant.components import recorder -from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.statistics import ( async_add_external_statistics, get_last_statistics, @@ -304,7 +303,7 @@ async def test_recorder_info_bad_recorder_config(hass, hass_ws_client): await hass.async_block_till_done() # Wait for recorder to shut down - await hass.async_add_executor_job(hass.data[DATA_INSTANCE].join) + await hass.async_add_executor_job(recorder.get_instance(hass).join) await client.send_json({"id": 1, "type": "recorder/info"}) response = await client.receive_json() diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index 4be59e4c82c..cc2f9c76f1f 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -10,7 +10,6 @@ from pytest import approx from homeassistant import loader from homeassistant.components.recorder import history -from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.db_schema import StatisticsMeta from homeassistant.components.recorder.models import process_timestamp_to_utc_isoformat from homeassistant.components.recorder.statistics import ( @@ -18,7 +17,7 @@ from homeassistant.components.recorder.statistics import ( list_statistic_ids, statistics_during_period, ) -from homeassistant.components.recorder.util import session_scope +from homeassistant.components.recorder.util import get_instance, session_scope from homeassistant.const import STATE_UNAVAILABLE from homeassistant.setup import async_setup_component, setup_component import homeassistant.util.dt as dt_util @@ -2290,7 +2289,7 @@ def test_compile_statistics_hourly_daily_monthly_summary(hass_recorder, caplog): hass = hass_recorder() # Remove this after dropping the use of the hass_recorder fixture hass.config.set_time_zone("America/Regina") - recorder = hass.data[DATA_INSTANCE] + instance = get_instance(hass) setup_component(hass, "sensor", {}) wait_recording_done(hass) # Wait for the sensor recorder platform to be added attributes = { @@ -2454,7 +2453,7 @@ def test_compile_statistics_hourly_daily_monthly_summary(hass_recorder, caplog): sum_adjustement_start = zero + timedelta(minutes=65) for i in range(13, 24): expected_sums["sensor.test4"][i] += sum_adjustment - recorder.async_adjust_statistics( + instance.async_adjust_statistics( "sensor.test4", sum_adjustement_start, sum_adjustment ) wait_recording_done(hass)