Convert test helpers to get hass instance to contextmanagers (#109990)

* Convert get_test_home_assistant helper to contextmanager

* Convert async_test_home_assistant helper to contextmanager

* Move timezone reset to async_test_home_assistant helper
Marc Mueller, 2024-02-11 21:23:51 +01:00, committed by GitHub
parent 3342e6ddbd
commit 2ef2172b01
17 changed files with 750 additions and 784 deletions
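
The practical effect of the conversion, as a minimal sketch (not part of the diff; `...` stands for arbitrary test code):

# Before this commit: the caller owns the instance and must clean up by hand.
hass = get_test_home_assistant()
...
hass.stop()

# After this commit: setup and teardown are tied to the with block.
with get_test_home_assistant() as hass:
    ...

# The async helper works the same way inside a coroutine.
async with async_test_home_assistant() as hass:
    ...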


@@ -3,8 +3,8 @@ from __future__ import annotations
 import asyncio
 from collections import OrderedDict
-from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager
+from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
+from contextlib import asynccontextmanager, contextmanager
 from datetime import UTC, datetime, timedelta
 from enum import Enum
 import functools as ft
@@ -153,15 +153,17 @@ def get_test_config_dir(*add_path):
     return os.path.join(os.path.dirname(__file__), "testing_config", *add_path)

-def get_test_home_assistant():
+@contextmanager
+def get_test_home_assistant() -> Generator[HomeAssistant, None, None]:
     """Return a Home Assistant object pointing at test config directory."""
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
-    hass = loop.run_until_complete(async_test_home_assistant(loop))
+    context_manager = async_test_home_assistant(loop)
+    hass = loop.run_until_complete(context_manager.__aenter__())

     loop_stop_event = threading.Event()

-    def run_loop():
+    def run_loop() -> None:
         """Run event loop."""
         loop._thread_ident = threading.get_ident()
@@ -171,25 +173,30 @@ def get_test_home_assistant():
     orig_stop = hass.stop
     hass._stopped = Mock(set=loop.stop)

-    def start_hass(*mocks):
+    def start_hass(*mocks: Any) -> None:
         """Start hass."""
         asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result()

-    def stop_hass():
+    def stop_hass() -> None:
         """Stop hass."""
         orig_stop()
         loop_stop_event.wait()
-        loop.close()

     hass.start = start_hass
     hass.stop = stop_hass

     threading.Thread(name="LoopThread", target=run_loop, daemon=False).start()

-    return hass
+    yield hass
+
+    loop.run_until_complete(context_manager.__aexit__(None, None, None))
+    loop.close()

-async def async_test_home_assistant(event_loop, load_registries=True):
+@asynccontextmanager
+async def async_test_home_assistant(
+    event_loop: asyncio.AbstractEventLoop | None = None,
+    load_registries: bool = True,
+) -> AsyncGenerator[HomeAssistant, None]:
     """Return a Home Assistant object pointing at test config dir."""
     hass = HomeAssistant(get_test_config_dir())
     store = auth_store.AuthStore(hass)
@@ -200,6 +207,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
     orig_async_add_job = hass.async_add_job
     orig_async_add_executor_job = hass.async_add_executor_job
     orig_async_create_task = hass.async_create_task
+    orig_tz = dt_util.DEFAULT_TIME_ZONE

     def async_add_job(target, *args):
         """Add job."""
@@ -300,7 +308,10 @@ async def async_test_home_assistant(event_loop, load_registries=True):
     hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance)

-    return hass
+    yield hass
+
+    # Restore timezone, it is set when creating the hass object
+    dt_util.DEFAULT_TIME_ZONE = orig_tz

 def async_mock_service(
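
Two details of the new helpers are worth spelling out: the sync `get_test_home_assistant` drives the async context manager by hand (`__aenter__` on entry, `__aexit__` after its own `yield`) on a private event loop, and `async_test_home_assistant` now restores the global `dt_util.DEFAULT_TIME_ZONE` after its `yield`, so individual tests no longer reset it. A stripped-down sketch of the same bridging pattern, with hypothetical names (`open_resource`/`open_resource_sync`), using only the standard library:

import asyncio
from collections.abc import AsyncGenerator, Generator
from contextlib import asynccontextmanager, contextmanager


@asynccontextmanager
async def open_resource() -> AsyncGenerator[str, None]:
    # setup, e.g. mutating a module-level default that must be restored
    yield "resource"
    # teardown runs here once the caller's block has finished


@contextmanager
def open_resource_sync() -> Generator[str, None, None]:
    """Synchronous facade over the async context manager above."""
    loop = asyncio.new_event_loop()
    cm = open_resource()
    try:
        # Enter the async context manager manually on a private loop.
        yield loop.run_until_complete(cm.__aenter__())
        # Exit it the same way; like the helper in this diff, the error
        # path is not forwarded into __aexit__.
        loop.run_until_complete(cm.__aexit__(None, None, None))
    finally:
        loop.close()

Callers then use `with open_resource_sync() as res: ...` exactly as with a native context manager.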


@@ -1,7 +1,6 @@
 """Tests for the Bluetooth integration."""
 from __future__ import annotations

-import asyncio
 from datetime import timedelta
 import logging
 import time
@@ -1673,93 +1672,93 @@ async def test_integration_multiple_entity_platforms_with_reload_and_restart(
     unregister_binary_sensor_processor()
     unregister_sensor_processor()

-    hass = await async_test_home_assistant(asyncio.get_running_loop())
+    async with async_test_home_assistant() as hass:
         await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
         current_entry.set(entry)
         coordinator = PassiveBluetoothProcessorCoordinator(
             hass,
             _LOGGER,
             "aa:bb:cc:dd:ee:ff",
             BluetoothScanningMode.ACTIVE,
             _mock_update_method,
         )
         assert coordinator.available is False  # no data yet
         mock_add_sensor_entities = MagicMock()
         mock_add_binary_sensor_entities = MagicMock()
         binary_sensor_processor = PassiveBluetoothDataProcessor(
             lambda service_info: DEVICE_ONLY_PASSIVE_BLUETOOTH_DATA_UPDATE,
             BINARY_SENSOR_DOMAIN,
         )
         sensor_processor = PassiveBluetoothDataProcessor(
             lambda service_info: DEVICE_ONLY_PASSIVE_BLUETOOTH_DATA_UPDATE,
             SENSOR_DOMAIN,
         )
         sensor_processor.async_add_entities_listener(
             PassiveBluetoothProcessorEntity,
             mock_add_sensor_entities,
         )
         binary_sensor_processor.async_add_entities_listener(
             PassiveBluetoothProcessorEntity,
             mock_add_binary_sensor_entities,
         )
         unregister_binary_sensor_processor = coordinator.async_register_processor(
             binary_sensor_processor, BinarySensorEntityDescription
         )
         unregister_sensor_processor = coordinator.async_register_processor(
             sensor_processor, SensorEntityDescription
         )
         cancel_coordinator = coordinator.async_start()

         assert len(mock_add_binary_sensor_entities.mock_calls) == 1
         assert len(mock_add_sensor_entities.mock_calls) == 1
         binary_sensor_entities = [
             *mock_add_binary_sensor_entities.mock_calls[0][1][0],
         ]
         sensor_entities = [
             *mock_add_sensor_entities.mock_calls[0][1][0],
         ]

         sensor_entity_one: PassiveBluetoothProcessorEntity = sensor_entities[0]
         sensor_entity_one.hass = hass
         assert sensor_entity_one.available is False  # service data not injected
         assert sensor_entity_one.unique_id == "aa:bb:cc:dd:ee:ff-pressure"
         assert sensor_entity_one.device_info == {
             "identifiers": {("bluetooth", "aa:bb:cc:dd:ee:ff")},
             "connections": {("bluetooth", "aa:bb:cc:dd:ee:ff")},
             "manufacturer": "Test Manufacturer",
             "model": "Test Model",
             "name": "Test Device",
         }
         assert sensor_entity_one.entity_key == PassiveBluetoothEntityKey(
             key="pressure", device_id=None
         )

-    binary_sensor_entity_one: PassiveBluetoothProcessorEntity = binary_sensor_entities[
-        0
-    ]
+        binary_sensor_entity_one: PassiveBluetoothProcessorEntity = (
+            binary_sensor_entities[0]
+        )
         binary_sensor_entity_one.hass = hass
         assert binary_sensor_entity_one.available is False  # service data not injected
         assert binary_sensor_entity_one.unique_id == "aa:bb:cc:dd:ee:ff-motion"
         assert binary_sensor_entity_one.device_info == {
             "identifiers": {("bluetooth", "aa:bb:cc:dd:ee:ff")},
             "connections": {("bluetooth", "aa:bb:cc:dd:ee:ff")},
             "manufacturer": "Test Manufacturer",
             "model": "Test Model",
             "name": "Test Device",
         }
         assert binary_sensor_entity_one.entity_key == PassiveBluetoothEntityKey(
             key="motion", device_id=None
         )
         cancel_coordinator()
         unregister_binary_sensor_processor()
         unregister_sensor_processor()
         await hass.async_stop()

 NAMING_PASSIVE_BLUETOOTH_DATA_UPDATE = PassiveBluetoothDataUpdate(
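
The test above tears one instance down and opens a second one to simulate a restart; with the helper as an async context manager that sequence reduces to two `async with` blocks. A sketch of the idiom (hypothetical test name; `...` stands for the assertions above):

async def test_survives_restart() -> None:  # hypothetical
    async with async_test_home_assistant() as hass:
        ...  # first boot: set up the integration, register processors
        await hass.async_stop()
    # The second block yields a completely fresh instance, which is what
    # the reload-and-restart test above relies on.
    async with async_test_home_assistant() as hass:
        ...  # second boot: assert that state was restored
        await hass.async_stop()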


@@ -75,32 +75,26 @@ class MockFFmpegDev(ffmpeg.FFmpegBase):
         self.called_entities = entity_ids

-class TestFFmpegSetup:
-    """Test class for ffmpeg."""
-
-    def setup_method(self):
-        """Set up things to be run when tests are started."""
-        self.hass = get_test_home_assistant()
-
-    def teardown_method(self):
-        """Stop everything that was started."""
-        self.hass.stop()
-
-    def test_setup_component(self):
-        """Set up ffmpeg component."""
+def test_setup_component():
+    """Set up ffmpeg component."""
+    with get_test_home_assistant() as hass:
         with assert_setup_component(1):
-            setup_component(self.hass, ffmpeg.DOMAIN, {ffmpeg.DOMAIN: {}})
+            setup_component(hass, ffmpeg.DOMAIN, {ffmpeg.DOMAIN: {}})

-        assert self.hass.data[ffmpeg.DATA_FFMPEG].binary == "ffmpeg"
+        assert hass.data[ffmpeg.DATA_FFMPEG].binary == "ffmpeg"
+        hass.stop()

-    def test_setup_component_test_service(self):
-        """Set up ffmpeg component test services."""
+
+def test_setup_component_test_service():
+    """Set up ffmpeg component test services."""
+    with get_test_home_assistant() as hass:
         with assert_setup_component(1):
-            setup_component(self.hass, ffmpeg.DOMAIN, {ffmpeg.DOMAIN: {}})
+            setup_component(hass, ffmpeg.DOMAIN, {ffmpeg.DOMAIN: {}})

-        assert self.hass.services.has_service(ffmpeg.DOMAIN, "start")
-        assert self.hass.services.has_service(ffmpeg.DOMAIN, "stop")
-        assert self.hass.services.has_service(ffmpeg.DOMAIN, "restart")
+        assert hass.services.has_service(ffmpeg.DOMAIN, "start")
+        assert hass.services.has_service(ffmpeg.DOMAIN, "stop")
+        assert hass.services.has_service(ffmpeg.DOMAIN, "restart")
+        hass.stop()

 async def test_setup_component_test_register(hass: HomeAssistant) -> None:


@@ -119,14 +119,19 @@ class TestComponentsCore(unittest.TestCase):
     def setUp(self):
         """Set up things to be run when tests are started."""
-        self.hass = get_test_home_assistant()
+        self._manager = get_test_home_assistant()
+        self.hass = self._manager.__enter__()
         assert asyncio.run_coroutine_threadsafe(
             async_setup_component(self.hass, "homeassistant", {}), self.hass.loop
         ).result()

         self.hass.states.set("light.Bowl", STATE_ON)
         self.hass.states.set("light.Ceiling", STATE_OFF)
-        self.addCleanup(self.hass.stop)
+
+    def tearDown(self) -> None:
+        """Tear down hass object."""
+        self.hass.stop()
+        self._manager.__exit__(None, None, None)

     def test_is_on(self):
         """Test is_on method."""


@@ -99,7 +99,8 @@ class TestPicnicSensor(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
         """Set up things to be run when tests are started."""
-        self.hass = await async_test_home_assistant(None)
+        self._manager = async_test_home_assistant()
+        self.hass = await self._manager.__aenter__()
         self.entity_registry = er.async_get(self.hass)

         # Patch the api client
@@ -122,6 +123,7 @@ class TestPicnicSensor(unittest.IsolatedAsyncioTestCase):
     async def asyncTearDown(self):
         """Tear down the test setup, stop hass/patchers."""
         await self.hass.async_stop(force=True)
+        await self._manager.__aexit__(None, None, None)
         self.picnic_patcher.stop()

     @property
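
The async analogue of the previous pattern: the manager is created without awaiting, entered in `asyncSetUp`, and exited in `asyncTearDown`. On Python 3.11+, `IsolatedAsyncioTestCase.enterAsyncContext` wraps the same steps; a sketch (not what this diff uses):

import unittest


class ExampleAsyncCase(unittest.IsolatedAsyncioTestCase):  # hypothetical
    async def asyncSetUp(self):
        # enterAsyncContext awaits __aenter__ and registers __aexit__
        # via addAsyncCleanup (Python 3.11+).
        self.hass = await self.enterAsyncContext(async_test_home_assistant())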


@@ -27,8 +27,6 @@ from ...common import wait_recording_done

 from tests.common import get_test_home_assistant

-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 def test_delete_duplicates_no_duplicates(
     hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
@@ -169,8 +167,7 @@ def test_delete_metadata_duplicates(
         recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
     ), patch(
         "homeassistant.components.recorder.core.create_engine", new=_create_engine_28
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -198,27 +195,25 @@ def test_delete_metadata_duplicates(
             assert tmp[2].statistic_id == "test:fossil_percentage"

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 28
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)

         assert "Deleted 1 duplicated statistics_meta rows" in caplog.text
         with session_scope(hass=hass) as session:
             tmp = session.query(recorder.db_schema.StatisticsMeta).all()
             assert len(tmp) == 2
             assert tmp[0].id == 2
             assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
             assert tmp[1].id == 3
             assert tmp[1].statistic_id == "test:fossil_percentage"

         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

 def test_delete_metadata_duplicates_many(
@@ -264,8 +259,7 @@ def test_delete_metadata_duplicates_many(
         recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
     ), patch(
         "homeassistant.components.recorder.core.create_engine", new=_create_engine_28
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -295,29 +289,27 @@ def test_delete_metadata_duplicates_many(
         )

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 28
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)

         assert "Deleted 1102 duplicated statistics_meta rows" in caplog.text
         with session_scope(hass=hass) as session:
             tmp = session.query(recorder.db_schema.StatisticsMeta).all()
             assert len(tmp) == 3
             assert tmp[0].id == 1101
             assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
             assert tmp[1].id == 1103
             assert tmp[1].statistic_id == "test:total_energy_import_tariff_2"
             assert tmp[2].id == 1105
             assert tmp[2].statistic_id == "test:fossil_percentage"

         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

 def test_delete_metadata_duplicates_no_duplicates(
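
A recurring edit in these recorder tests: instead of creating `hass` on the first line of a `with patch(...)` body, the helper joins the same `with` statement. Comma-separated context managers behave like nested ones, entered left to right and exited in reverse, so `hass` is torn down before the patches are undone. Sketch, reusing names from the test above:

with patch(
    "homeassistant.components.recorder.core.create_engine", new=_create_engine_28
), get_test_home_assistant() as hass:
    # hass exists only while create_engine is patched; leaving the block
    # exits get_test_home_assistant() first, then undoes the patch.
    setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})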


@@ -1301,22 +1301,22 @@ def test_compile_missing_statistics(
     test_db_file = test_dir.joinpath("test_run_info.db")
     dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"

-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)

         with session_scope(hass=hass, read_only=True) as session:
             statistics_runs = list(session.query(StatisticsRuns))
             assert len(statistics_runs) == 1
             last_run = process_timestamp(statistics_runs[0].start)
             assert last_run == now - timedelta(minutes=5)

         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()

     # Start Home Assistant one hour later
     stats_5min = []
@@ -1332,32 +1332,33 @@ def test_compile_missing_statistics(
             stats_hourly.append(event)

     freezer.tick(timedelta(hours=1))
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         hass.bus.listen(
             EVENT_RECORDER_5MIN_STATISTICS_GENERATED, async_5min_stats_updated_listener
         )
         hass.bus.listen(
-        EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, async_hourly_stats_updated_listener
-    )
+            EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
+            async_hourly_stats_updated_listener,
+        )
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)

         with session_scope(hass=hass, read_only=True) as session:
             statistics_runs = list(session.query(StatisticsRuns))
             assert len(statistics_runs) == 13  # 12 5-minute runs
             last_run = process_timestamp(statistics_runs[1].start)
             assert last_run == now

         assert len(stats_5min) == 1
         assert len(stats_hourly) == 1

         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()

 def test_saving_sets_old_state(hass_recorder: Callable[..., HomeAssistant]) -> None:
@@ -1562,43 +1563,43 @@ def test_service_disable_run_information_recorded(tmp_path: Path) -> None:
     test_db_file = test_dir.joinpath("test_run_info.db")
     dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"

-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
         hass.start()
         wait_recording_done(hass)

         with session_scope(hass=hass, read_only=True) as session:
             db_run_info = list(session.query(RecorderRuns))
             assert len(db_run_info) == 1
             assert db_run_info[0].start is not None
             assert db_run_info[0].end is None

         hass.services.call(
             DOMAIN,
             SERVICE_DISABLE,
             {},
             blocking=True,
         )
         wait_recording_done(hass)
         hass.stop()

-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
         hass.start()
         wait_recording_done(hass)

         with session_scope(hass=hass, read_only=True) as session:
             db_run_info = list(session.query(RecorderRuns))
             assert len(db_run_info) == 2
             assert db_run_info[0].start is not None
             assert db_run_info[0].end is not None
             assert db_run_info[1].start is not None
             assert db_run_info[1].end is None

         hass.stop()

 class CannotSerializeMe:


@@ -36,8 +36,6 @@ from .common import async_wait_recording_done, create_engine_test

 from tests.common import async_fire_time_changed

-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 def _get_native_states(hass, entity_id):
     with session_scope(hass=hass, read_only=True) as session:


@@ -44,7 +44,6 @@ from tests.typing import RecorderInstanceGenerator

 CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine"
 SCHEMA_MODULE = "tests.components.recorder.db_schema_32"
-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 async def _async_wait_migration_done(hass: HomeAssistant) -> None:


@@ -49,8 +49,6 @@ from .common import (
 from tests.common import mock_registry
 from tests.typing import WebSocketGenerator

-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 def test_converters_align_with_sensor() -> None:
     """Ensure STATISTIC_UNIT_TO_UNIT_CONVERTER is aligned with UNIT_CONVERTERS."""


@@ -28,8 +28,6 @@ from .common import (
 from tests.common import get_test_home_assistant

-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 SCHEMA_VERSION_POSTFIX = "23_with_newer_columns"
 SCHEMA_MODULE = get_schema_module_path(SCHEMA_VERSION_POSTFIX)
@@ -169,8 +167,7 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) ->
             create_engine_test_for_schema_version_postfix,
             schema_version_postfix=SCHEMA_VERSION_POSTFIX,
         ),
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -195,17 +192,15 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) ->
                     session.add(recorder.db_schema.Statistics.from_stats(3, stat))

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 23
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     assert "Deleted 2 duplicated statistics rows" in caplog.text
     assert "Found non identical" not in caplog.text
@@ -349,8 +344,7 @@ def test_delete_duplicates_many(
             create_engine_test_for_schema_version_postfix,
             schema_version_postfix=SCHEMA_VERSION_POSTFIX,
         ),
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -381,17 +375,15 @@ def test_delete_duplicates_many(
                     session.add(recorder.db_schema.Statistics.from_stats(3, stat))

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 23
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     assert "Deleted 3002 duplicated statistics rows" in caplog.text
     assert "Found non identical" not in caplog.text
@@ -506,8 +498,7 @@ def test_delete_duplicates_non_identical(
             create_engine_test_for_schema_version_postfix,
             schema_version_postfix=SCHEMA_VERSION_POSTFIX,
         ),
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -527,18 +518,16 @@ def test_delete_duplicates_non_identical(
                     session.add(recorder.db_schema.Statistics.from_stats(2, stat))

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 23
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         hass.config.config_dir = tmp_path
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     assert "Deleted 2 duplicated statistics rows" in caplog.text
     assert "Deleted 1 non identical" in caplog.text
@@ -618,8 +607,7 @@ def test_delete_duplicates_short_term(
             create_engine_test_for_schema_version_postfix,
             schema_version_postfix=SCHEMA_VERSION_POSTFIX,
         ),
-    ):
-        hass = get_test_home_assistant()
+    ), get_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         wait_recording_done(hass)
@@ -638,18 +626,16 @@ def test_delete_duplicates_short_term(
         )

         hass.stop()
-        dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     # Test that the duplicates are removed during migration from schema 23
-    hass = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         hass.config.config_dir = tmp_path
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)
         hass.stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

     assert "duplicated statistics rows" not in caplog.text
     assert "Found non identical" not in caplog.text


@@ -106,60 +106,59 @@ async def test_last_run_was_recently_clean(
         recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db"),
         recorder.CONF_COMMIT_INTERVAL: 1,
     }
-    hass = await async_test_home_assistant(None)
+    async with async_test_home_assistant() as hass:
         return_values = []
         real_last_run_was_recently_clean = util.last_run_was_recently_clean

         def _last_run_was_recently_clean(cursor):
             return_values.append(real_last_run_was_recently_clean(cursor))
             return return_values[-1]

         # Test last_run_was_recently_clean is not called on new DB
         with patch(
             "homeassistant.components.recorder.util.last_run_was_recently_clean",
             wraps=_last_run_was_recently_clean,
         ) as last_run_was_recently_clean_mock:
             await async_setup_recorder_instance(hass, config)
             await hass.async_block_till_done()
             last_run_was_recently_clean_mock.assert_not_called()

         # Restart HA, last_run_was_recently_clean should return True
         hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
         await hass.async_block_till_done()
         await hass.async_stop()

-    with patch(
-        "homeassistant.components.recorder.util.last_run_was_recently_clean",
-        wraps=_last_run_was_recently_clean,
-    ) as last_run_was_recently_clean_mock:
-        hass = await async_test_home_assistant(None)
-        await async_setup_recorder_instance(hass, config)
-        last_run_was_recently_clean_mock.assert_called_once()
-        assert return_values[-1] is True
+    async with async_test_home_assistant() as hass:
+        with patch(
+            "homeassistant.components.recorder.util.last_run_was_recently_clean",
+            wraps=_last_run_was_recently_clean,
+        ) as last_run_was_recently_clean_mock:
+            await async_setup_recorder_instance(hass, config)
+            last_run_was_recently_clean_mock.assert_called_once()
+            assert return_values[-1] is True

         # Restart HA with a long downtime, last_run_was_recently_clean should return False
         hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
         await hass.async_block_till_done()
         await hass.async_stop()

     thirty_min_future_time = dt_util.utcnow() + timedelta(minutes=30)

-    with patch(
-        "homeassistant.components.recorder.util.last_run_was_recently_clean",
-        wraps=_last_run_was_recently_clean,
-    ) as last_run_was_recently_clean_mock, patch(
-        "homeassistant.components.recorder.core.dt_util.utcnow",
-        return_value=thirty_min_future_time,
-    ):
-        hass = await async_test_home_assistant(None)
-        await async_setup_recorder_instance(hass, config)
-        last_run_was_recently_clean_mock.assert_called_once()
-        assert return_values[-1] is False
+    async with async_test_home_assistant() as hass:
+        with patch(
+            "homeassistant.components.recorder.util.last_run_was_recently_clean",
+            wraps=_last_run_was_recently_clean,
+        ) as last_run_was_recently_clean_mock, patch(
+            "homeassistant.components.recorder.core.dt_util.utcnow",
+            return_value=thirty_min_future_time,
+        ):
+            await async_setup_recorder_instance(hass, config)
+            last_run_was_recently_clean_mock.assert_called_once()
+            assert return_values[-1] is False

         hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
         await hass.async_block_till_done()
         await hass.async_stop()

 @pytest.mark.parametrize(
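
Note how the restructured test keeps one `return_values` list and one `_last_run_was_recently_clean` closure alive across three `async with` blocks: each block boots a fresh instance against the same SQLite file, simulating restarts with increasing downtime. A sketch of the shape, reusing the fixtures from the test above (hypothetical test name):

async def test_restart_shape(tmp_path) -> None:  # hypothetical
    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
    async with async_test_home_assistant() as hass:
        await async_setup_recorder_instance(hass, config)  # cold start
        await hass.async_stop()
    async with async_test_home_assistant() as hass:
        await async_setup_recorder_instance(hass, config)  # sees the prior run's DB
        await hass.async_stop()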


@@ -1,5 +1,4 @@
 """The tests for recorder platform migrating data from v30."""
-import asyncio
 from datetime import timedelta
 import importlib
 from pathlib import Path
@@ -23,8 +22,6 @@ from .common import async_wait_recording_done

 from tests.common import async_test_home_assistant

-ORIG_TZ = dt_util.DEFAULT_TIME_ZONE

 CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine"
 SCHEMA_MODULE = "tests.components.recorder.db_schema_32"
@@ -115,108 +112,105 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmp_path: Path) -
     ), patch(
         "homeassistant.components.recorder.Recorder._cleanup_legacy_states_event_ids"
     ):
-        hass = await async_test_home_assistant(asyncio.get_running_loop())
+        async with async_test_home_assistant() as hass:
             recorder_helper.async_initialize_recorder(hass)
             assert await async_setup_component(
                 hass, "recorder", {"recorder": {"db_url": dburl}}
             )
             await hass.async_block_till_done()
             await async_wait_recording_done(hass)
             await async_wait_recording_done(hass)

             def _add_data():
                 with session_scope(hass=hass) as session:
                     session.add(old_db_schema.Events.from_event(custom_event))
                     session.add(old_db_schema.States.from_event(state_changed_event))

             await recorder.get_instance(hass).async_add_executor_job(_add_data)
             await hass.async_block_till_done()
             await recorder.get_instance(hass).async_block_till_done()

             states_indexes = await recorder.get_instance(hass).async_add_executor_job(
                 _get_states_index_names
             )
             states_index_names = {index["name"] for index in states_indexes}
             assert recorder.get_instance(hass).use_legacy_events_index is True

             await hass.async_stop()
             await hass.async_block_till_done()

     assert "ix_states_event_id" in states_index_names

     # Test that the duplicates are removed during migration from schema 23
-    hass = await async_test_home_assistant(asyncio.get_running_loop())
+    async with async_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         assert await async_setup_component(
             hass, "recorder", {"recorder": {"db_url": dburl}}
         )
         await hass.async_block_till_done()

         # We need to wait for all the migration tasks to complete
         # before we can check the database.
         for _ in range(number_of_migrations):
             await recorder.get_instance(hass).async_block_till_done()
             await async_wait_recording_done(hass)

         def _get_test_data_from_db():
             with session_scope(hass=hass) as session:
                 events_result = list(
                     session.query(recorder.db_schema.Events).filter(
                         recorder.db_schema.Events.event_type_id.in_(
                             select_event_type_ids(("custom_event",))
                         )
                     )
                 )
                 states_result = list(
                     session.query(recorder.db_schema.States)
                     .join(
                         recorder.db_schema.StatesMeta,
                         recorder.db_schema.States.metadata_id
                         == recorder.db_schema.StatesMeta.metadata_id,
                     )
                     .where(recorder.db_schema.StatesMeta.entity_id == "sensor.test")
                 )
                 session.expunge_all()
                 return events_result, states_result

         events_result, states_result = await recorder.get_instance(
             hass
         ).async_add_executor_job(_get_test_data_from_db)

         assert len(events_result) == 1
         assert events_result[0].time_fired_ts == now_timestamp
         assert len(states_result) == 1
         assert states_result[0].last_changed_ts == one_second_past_timestamp
         assert states_result[0].last_updated_ts == now_timestamp

         def _get_events_index_names():
             with session_scope(hass=hass) as session:
                 return inspect(session.connection()).get_indexes("events")

         events_indexes = await recorder.get_instance(hass).async_add_executor_job(
             _get_events_index_names
         )
         events_index_names = {index["name"] for index in events_indexes}

         assert "ix_events_context_id_bin" in events_index_names
         assert "ix_events_context_id" not in events_index_names

         states_indexes = await recorder.get_instance(hass).async_add_executor_job(
             _get_states_index_names
         )
         states_index_names = {index["name"] for index in states_indexes}

         # sqlite does not support dropping foreign keys so the
         # ix_states_event_id index is not dropped in this case
         # but use_legacy_events_index is still False
         assert "ix_states_event_id" in states_index_names
         assert recorder.get_instance(hass).use_legacy_events_index is False

         await hass.async_stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
 async def test_migrate_can_resume_entity_id_post_migration(
@@ -282,38 +276,60 @@ async def test_migrate_can_resume_entity_id_post_migration(
     ), patch(
         "homeassistant.components.recorder.Recorder._cleanup_legacy_states_event_ids"
     ):
-        hass = await async_test_home_assistant(asyncio.get_running_loop())
+        async with async_test_home_assistant() as hass:
             recorder_helper.async_initialize_recorder(hass)
             assert await async_setup_component(
                 hass, "recorder", {"recorder": {"db_url": dburl}}
             )
             await hass.async_block_till_done()
             await async_wait_recording_done(hass)
             await async_wait_recording_done(hass)

             def _add_data():
                 with session_scope(hass=hass) as session:
                     session.add(old_db_schema.Events.from_event(custom_event))
                     session.add(old_db_schema.States.from_event(state_changed_event))

             await recorder.get_instance(hass).async_add_executor_job(_add_data)
             await hass.async_block_till_done()
             await recorder.get_instance(hass).async_block_till_done()

             states_indexes = await recorder.get_instance(hass).async_add_executor_job(
                 _get_states_index_names
             )
             states_index_names = {index["name"] for index in states_indexes}
             assert recorder.get_instance(hass).use_legacy_events_index is True

             await hass.async_stop()
             await hass.async_block_till_done()

     assert "ix_states_event_id" in states_index_names
     assert "ix_states_entity_id_last_updated_ts" in states_index_names

     with patch("homeassistant.components.recorder.Recorder._post_migrate_entity_ids"):
-        hass = await async_test_home_assistant(asyncio.get_running_loop())
+        async with async_test_home_assistant() as hass:
+            recorder_helper.async_initialize_recorder(hass)
+            assert await async_setup_component(
+                hass, "recorder", {"recorder": {"db_url": dburl}}
+            )
+            await hass.async_block_till_done()
+
+            # We need to wait for all the migration tasks to complete
+            # before we can check the database.
+            for _ in range(number_of_migrations):
+                await recorder.get_instance(hass).async_block_till_done()
+                await async_wait_recording_done(hass)
+
+            states_indexes = await recorder.get_instance(hass).async_add_executor_job(
+                _get_states_index_names
+            )
+            states_index_names = {index["name"] for index in states_indexes}
+            await hass.async_stop()
+            await hass.async_block_till_done()
+
+    assert "ix_states_entity_id_last_updated_ts" in states_index_names
+
+    async with async_test_home_assistant() as hass:
         recorder_helper.async_initialize_recorder(hass)
         assert await async_setup_component(
             hass, "recorder", {"recorder": {"db_url": dburl}}
@@ -330,29 +346,6 @@ async def test_migrate_can_resume_entity_id_post_migration(
             _get_states_index_names
         )
         states_index_names = {index["name"] for index in states_indexes}
+        assert "ix_states_entity_id_last_updated_ts" not in states_index_names
         await hass.async_stop()
-    await hass.async_block_till_done()
-
-    assert "ix_states_entity_id_last_updated_ts" in states_index_names
-
-    hass = await async_test_home_assistant(asyncio.get_running_loop())
-    recorder_helper.async_initialize_recorder(hass)
-    assert await async_setup_component(
-        hass, "recorder", {"recorder": {"db_url": dburl}}
-    )
-    await hass.async_block_till_done()
-
-    # We need to wait for all the migration tasks to complete
-    # before we can check the database.
-    for _ in range(number_of_migrations):
-        await recorder.get_instance(hass).async_block_till_done()
-        await async_wait_recording_done(hass)
-
-    states_indexes = await recorder.get_instance(hass).async_add_executor_job(
-        _get_states_index_names
-    )
-    states_index_names = {index["name"] for index in states_indexes}
-    assert "ix_states_entity_id_last_updated_ts" not in states_index_names
-    await hass.async_stop()
-    dt_util.DEFAULT_TIME_ZONE = ORIG_TZ

@@ -12,7 +12,7 @@ from homeassistant.components.recorder.statistics import (
     statistics_during_period,
 )
 from homeassistant.components.recorder.util import session_scope
-from homeassistant.core import CoreState, HomeAssistant
+from homeassistant.core import CoreState
 from homeassistant.helpers import recorder as recorder_helper
 from homeassistant.setup import setup_component
 import homeassistant.util.dt as dt_util
@@ -51,74 +51,76 @@ def test_compile_missing_statistics(
     three_days_ago = datetime(2021, 1, 1, 0, 0, 0, tzinfo=dt_util.UTC)
     start_time = three_days_ago + timedelta(days=3)
     freezer.move_to(three_days_ago)
-    hass: HomeAssistant = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         hass.set_state(CoreState.not_running)
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "sensor", {})
         setup_component(hass, "recorder", {"recorder": config})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)

         hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES)
         wait_recording_done(hass)

         two_days_ago = three_days_ago + timedelta(days=1)
         freezer.move_to(two_days_ago)
         do_adhoc_statistics(hass, start=two_days_ago)
         wait_recording_done(hass)

         with session_scope(hass=hass, read_only=True) as session:
             latest = get_latest_short_term_statistics_with_session(
                 hass, session, {"sensor.test1"}, {"state", "sum"}
             )
         latest_stat = latest["sensor.test1"][0]
         assert latest_stat["start"] == 1609545600.0
         assert latest_stat["end"] == 1609545600.0 + 300
         count = 1
         past_time = two_days_ago
         while past_time <= start_time:
             freezer.move_to(past_time)
             hass.states.set("sensor.test1", str(count), POWER_SENSOR_ATTRIBUTES)
             past_time += timedelta(minutes=5)
             count += 1

         wait_recording_done(hass)

-        states = get_significant_states(hass, three_days_ago, past_time, ["sensor.test1"])
+        states = get_significant_states(
+            hass, three_days_ago, past_time, ["sensor.test1"]
+        )
         assert len(states["sensor.test1"]) == 577

         hass.stop()
     freezer.move_to(start_time)
-    hass: HomeAssistant = get_test_home_assistant()
+    with get_test_home_assistant() as hass:
         hass.set_state(CoreState.not_running)
         recorder_helper.async_initialize_recorder(hass)
         setup_component(hass, "sensor", {})
         hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES)
         setup_component(hass, "recorder", {"recorder": config})
         hass.start()
         wait_recording_done(hass)
         wait_recording_done(hass)
         with session_scope(hass=hass, read_only=True) as session:
             latest = get_latest_short_term_statistics_with_session(
                 hass, session, {"sensor.test1"}, {"state", "sum", "max", "mean", "min"}
             )
         latest_stat = latest["sensor.test1"][0]
         assert latest_stat["start"] == 1609718100.0
         assert latest_stat["end"] == 1609718100.0 + 300
         assert latest_stat["mean"] == 576.0
         assert latest_stat["min"] == 575.0
         assert latest_stat["max"] == 576.0
         stats = statistics_during_period(
             hass,
             two_days_ago,
             start_time,
             units={"energy": "kWh"},
             statistic_ids={"sensor.test1"},
             period="hour",
             types={"mean"},
         )
         # Make sure we have 48 hours of statistics
         assert len(stats["sensor.test1"]) == 48
         # Make sure the last mean is 570.5
         assert stats["sensor.test1"][-1]["mean"] == 570.5
         hass.stop()


@@ -60,7 +60,7 @@ from homeassistant.helpers import (
 )
 from homeassistant.helpers.typing import ConfigType
 from homeassistant.setup import BASE_PLATFORMS, async_setup_component
-from homeassistant.util import dt as dt_util, location
+from homeassistant.util import location
 from homeassistant.util.json import json_loads

 from .ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
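The dt as dt_util import can be dropped here because the fixtures below no longer save and restore DEFAULT_TIME_ZONE themselves; that reset now lives inside the async_test_home_assistant helper. A sketch of the underlying technique, assuming a try/finally around the yield (the helper's real body is not part of this hunk):

    from contextlib import asynccontextmanager

    from homeassistant.util import dt as dt_util


    @asynccontextmanager
    async def _example_helper():
        """Sketch: restore module-level state however the test exits."""
        orig_tz = dt_util.DEFAULT_TIME_ZONE  # capture before the test mutates it
        try:
            yield
        finally:
            dt_util.DEFAULT_TIME_ZONE = orig_tz  # runs on success and on failure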
@@ -526,8 +526,6 @@ async def hass(
     loop = asyncio.get_running_loop()
     hass_fixture_setup.append(True)

-    orig_tz = dt_util.DEFAULT_TIME_ZONE
-
     def exc_handle(loop, context):
         """Handle exceptions by rethrowing them, which will fail the test."""
         # Most of these contexts will contain an exception, but not all.
@@ -545,26 +543,22 @@ async def hass(
         orig_exception_handler(loop, context)

     exceptions: list[Exception] = []
-    hass = await async_test_home_assistant(loop, load_registries)
-
-    orig_exception_handler = loop.get_exception_handler()
-    loop.set_exception_handler(exc_handle)
-
-    yield hass
-
-    # Config entries are not normally unloaded on HA shutdown. They are unloaded here
-    # to ensure that they could, and to help track lingering tasks and timers.
-    await asyncio.gather(
-        *(
-            config_entry.async_unload(hass)
-            for config_entry in hass.config_entries.async_entries()
-        )
-    )
-
-    await hass.async_stop(force=True)
-
-    # Restore timezone, it is set when creating the hass object
-    dt_util.DEFAULT_TIME_ZONE = orig_tz
+    async with async_test_home_assistant(loop, load_registries) as hass:
+        orig_exception_handler = loop.get_exception_handler()
+        loop.set_exception_handler(exc_handle)
+
+        yield hass
+
+        # Config entries are not normally unloaded on HA shutdown. They are unloaded here
+        # to ensure that they could, and to help track lingering tasks and timers.
+        await asyncio.gather(
+            *(
+                config_entry.async_unload(hass)
+                for config_entry in hass.config_entries.async_entries()
+            )
+        )
+
+        await hass.async_stop(force=True)

     for ex in exceptions:
         if (
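Note the new shape of the fixture: the `yield hass` that hands control to the test body now sits inside `async with`, so the unload-and-stop teardown after it runs before the helper's own cleanup. Reduced to a skeleton (a sketch; the real fixture also wires up exception handling and registry loading):

    import pytest

    from tests.common import async_test_home_assistant


    @pytest.fixture
    async def hass_example():
        """Sketch: teardown after the yield runs inside the helper's context."""
        async with async_test_home_assistant() as hass:
            yield hass  # the test body executes here
            await hass.async_stop(force=True)  # then teardown, then helper cleanup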
@@ -1305,86 +1299,85 @@ def hass_recorder(
     # pylint: disable-next=import-outside-toplevel
     from homeassistant.components.recorder import migration

-    original_tz = dt_util.DEFAULT_TIME_ZONE
-
-    hass = get_test_home_assistant()
-    nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
-    stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
-    compile_missing = (
-        recorder.Recorder._schedule_compile_missing_statistics
-        if enable_statistics
-        else None
-    )
-    schema_validate = (
-        migration._find_schema_errors
-        if enable_schema_validation
-        else itertools.repeat(set())
-    )
-    migrate_states_context_ids = (
-        recorder.Recorder._migrate_states_context_ids
-        if enable_migrate_context_ids
-        else None
-    )
-    migrate_events_context_ids = (
-        recorder.Recorder._migrate_events_context_ids
-        if enable_migrate_context_ids
-        else None
-    )
-    migrate_event_type_ids = (
-        recorder.Recorder._migrate_event_type_ids
-        if enable_migrate_event_type_ids
-        else None
-    )
-    migrate_entity_ids = (
-        recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None
-    )
-    with patch(
-        "homeassistant.components.recorder.Recorder.async_nightly_tasks",
-        side_effect=nightly,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder.async_periodic_statistics",
-        side_effect=stats,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.migration._find_schema_errors",
-        side_effect=schema_validate,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder._migrate_events_context_ids",
-        side_effect=migrate_events_context_ids,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder._migrate_states_context_ids",
-        side_effect=migrate_states_context_ids,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder._migrate_event_type_ids",
-        side_effect=migrate_event_type_ids,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder._migrate_entity_ids",
-        side_effect=migrate_entity_ids,
-        autospec=True,
-    ), patch(
-        "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
-        side_effect=compile_missing,
-        autospec=True,
-    ):
-
-        def setup_recorder(config: dict[str, Any] | None = None) -> HomeAssistant:
-            """Set up with params."""
-            init_recorder_component(hass, config, recorder_db_url)
-            hass.start()
-            hass.block_till_done()
-            hass.data[recorder.DATA_INSTANCE].block_till_done()
-            return hass
-
-        yield setup_recorder
-
-        hass.stop()
-
-        # Restore timezone, it is set when creating the hass object
-        dt_util.DEFAULT_TIME_ZONE = original_tz
+    with get_test_home_assistant() as hass:
+        nightly = (
+            recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
+        )
+        stats = (
+            recorder.Recorder.async_periodic_statistics if enable_statistics else None
+        )
+        compile_missing = (
+            recorder.Recorder._schedule_compile_missing_statistics
+            if enable_statistics
+            else None
+        )
+        schema_validate = (
+            migration._find_schema_errors
+            if enable_schema_validation
+            else itertools.repeat(set())
+        )
+        migrate_states_context_ids = (
+            recorder.Recorder._migrate_states_context_ids
+            if enable_migrate_context_ids
+            else None
+        )
+        migrate_events_context_ids = (
+            recorder.Recorder._migrate_events_context_ids
+            if enable_migrate_context_ids
+            else None
+        )
+        migrate_event_type_ids = (
+            recorder.Recorder._migrate_event_type_ids
+            if enable_migrate_event_type_ids
+            else None
+        )
+        migrate_entity_ids = (
+            recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None
+        )
+        with patch(
+            "homeassistant.components.recorder.Recorder.async_nightly_tasks",
+            side_effect=nightly,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder.async_periodic_statistics",
+            side_effect=stats,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.migration._find_schema_errors",
+            side_effect=schema_validate,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder._migrate_events_context_ids",
+            side_effect=migrate_events_context_ids,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder._migrate_states_context_ids",
+            side_effect=migrate_states_context_ids,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder._migrate_event_type_ids",
+            side_effect=migrate_event_type_ids,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder._migrate_entity_ids",
+            side_effect=migrate_entity_ids,
+            autospec=True,
+        ), patch(
+            "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
+            side_effect=compile_missing,
+            autospec=True,
+        ):
+
+            def setup_recorder(config: dict[str, Any] | None = None) -> HomeAssistant:
+                """Set up with params."""
+                init_recorder_component(hass, config, recorder_db_url)
+                hass.start()
+                hass.block_till_done()
+                hass.data[recorder.DATA_INSTANCE].block_till_done()
+                return hass
+
+            yield setup_recorder
+
+            hass.stop()

 async def _async_init_recorder_component(
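From a consuming test's perspective nothing changes: hass_recorder still yields the setup_recorder factory; the fixture, rather than the test, now owns the instance's lifecycle. A hypothetical consumer (the config dict is illustrative only):

    def test_recorder_example(hass_recorder) -> None:
        """Sketch: ask the factory for a configured instance, never stop it."""
        hass = hass_recorder({"commit_interval": 0})  # illustrative recorder config
        hass.states.set("sensor.example", "1")
        wait_recording_done(hass)  # helper from tests.components.recorder.common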


@ -516,199 +516,192 @@ async def test_changing_delayed_written_data(
async def test_saving_load_round_trip(tmpdir: py.path.local) -> None: async def test_saving_load_round_trip(tmpdir: py.path.local) -> None:
"""Test saving and loading round trip.""" """Test saving and loading round trip."""
loop = asyncio.get_running_loop() async with async_test_home_assistant() as hass:
hass = await async_test_home_assistant(loop) hass.config.config_dir = await hass.async_add_executor_job(
tmpdir.mkdir, "temp_storage"
)
hass.config.config_dir = await hass.async_add_executor_job( class NamedTupleSubclass(NamedTuple):
tmpdir.mkdir, "temp_storage" """A NamedTuple subclass."""
)
class NamedTupleSubclass(NamedTuple): name: str
"""A NamedTuple subclass."""
name: str nts = NamedTupleSubclass("a")
nts = NamedTupleSubclass("a") data = {
"named_tuple_subclass": nts,
"rgb_color": RGBColor(255, 255, 0),
"set": {1, 2, 3},
"list": [1, 2, 3],
"tuple": (1, 2, 3),
"dict_with_int": {1: 1, 2: 2},
"dict_with_named_tuple": {1: nts, 2: nts},
}
data = { store = storage.Store(
"named_tuple_subclass": nts, hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
"rgb_color": RGBColor(255, 255, 0), )
"set": {1, 2, 3}, await store.async_save(data)
"list": [1, 2, 3], load = await store.async_load()
"tuple": (1, 2, 3), assert load == {
"dict_with_int": {1: 1, 2: 2}, "dict_with_int": {"1": 1, "2": 2},
"dict_with_named_tuple": {1: nts, 2: nts}, "dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
} "list": [1, 2, 3],
"named_tuple_subclass": ["a"],
"rgb_color": [255, 255, 0],
"set": [1, 2, 3],
"tuple": [1, 2, 3],
}
store = storage.Store( await hass.async_stop(force=True)
hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
)
await store.async_save(data)
load = await store.async_load()
assert load == {
"dict_with_int": {"1": 1, "2": 2},
"dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
"list": [1, 2, 3],
"named_tuple_subclass": ["a"],
"rgb_color": [255, 255, 0],
"set": [1, 2, 3],
"tuple": [1, 2, 3],
}
await hass.async_stop(force=True)

 async def test_loading_corrupt_core_file(
     tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
 ) -> None:
     """Test we handle unrecoverable corruption in a core file."""
-    loop = asyncio.get_running_loop()
-    hass = await async_test_home_assistant(loop)
-
-    tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
-    hass.config.config_dir = tmp_storage
-
-    storage_key = "core.anything"
-    store = storage.Store(
-        hass, MOCK_VERSION_2, storage_key, minor_version=MOCK_MINOR_VERSION_1
-    )
-    await store.async_save({"hello": "world"})
-    storage_path = os.path.join(tmp_storage, ".storage")
-    store_file = os.path.join(storage_path, store.key)
-
-    data = await store.async_load()
-    assert data == {"hello": "world"}
-
-    def _corrupt_store():
-        with open(store_file, "w") as f:
-            f.write("corrupt")
-
-    await hass.async_add_executor_job(_corrupt_store)
-
-    data = await store.async_load()
-    assert data is None
-    assert "Unrecoverable error decoding storage" in caplog.text
-
-    issue_registry = ir.async_get(hass)
-    found_issue = None
-    issue_entry = None
-    for (domain, issue), entry in issue_registry.issues.items():
-        if domain == HOMEASSISTANT_DOMAIN and issue.startswith(
-            f"storage_corruption_{storage_key}_"
-        ):
-            found_issue = issue
-            issue_entry = entry
-            break
-
-    assert found_issue is not None
-    assert issue_entry is not None
-    assert issue_entry.is_fixable is True
-    assert issue_entry.translation_placeholders["storage_key"] == storage_key
-    assert issue_entry.issue_domain == HOMEASSISTANT_DOMAIN
-    assert (
-        issue_entry.translation_placeholders["error"]
-        == "unexpected character: line 1 column 1 (char 0)"
-    )
-
-    files = await hass.async_add_executor_job(
-        os.listdir, os.path.join(tmp_storage, ".storage")
-    )
-    assert ".corrupt" in files[0]
-
-    await hass.async_stop(force=True)
+    async with async_test_home_assistant() as hass:
+        tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
+        hass.config.config_dir = tmp_storage
+
+        storage_key = "core.anything"
+        store = storage.Store(
+            hass, MOCK_VERSION_2, storage_key, minor_version=MOCK_MINOR_VERSION_1
+        )
+        await store.async_save({"hello": "world"})
+        storage_path = os.path.join(tmp_storage, ".storage")
+        store_file = os.path.join(storage_path, store.key)
+
+        data = await store.async_load()
+        assert data == {"hello": "world"}
+
+        def _corrupt_store():
+            with open(store_file, "w") as f:
+                f.write("corrupt")
+
+        await hass.async_add_executor_job(_corrupt_store)
+
+        data = await store.async_load()
+        assert data is None
+        assert "Unrecoverable error decoding storage" in caplog.text
+
+        issue_registry = ir.async_get(hass)
+        found_issue = None
+        issue_entry = None
+        for (domain, issue), entry in issue_registry.issues.items():
+            if domain == HOMEASSISTANT_DOMAIN and issue.startswith(
+                f"storage_corruption_{storage_key}_"
+            ):
+                found_issue = issue
+                issue_entry = entry
+                break
+
+        assert found_issue is not None
+        assert issue_entry is not None
+        assert issue_entry.is_fixable is True
+        assert issue_entry.translation_placeholders["storage_key"] == storage_key
+        assert issue_entry.issue_domain == HOMEASSISTANT_DOMAIN
+        assert (
+            issue_entry.translation_placeholders["error"]
+            == "unexpected character: line 1 column 1 (char 0)"
+        )
+
+        files = await hass.async_add_executor_job(
+            os.listdir, os.path.join(tmp_storage, ".storage")
+        )
+        assert ".corrupt" in files[0]
+
+        await hass.async_stop(force=True)

 async def test_loading_corrupt_file_known_domain(
     tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
 ) -> None:
     """Test we handle unrecoverable corruption for a known domain."""
-    loop = asyncio.get_running_loop()
-    hass = await async_test_home_assistant(loop)
-    hass.config.components.add("testdomain")
-    storage_key = "testdomain.testkey"
-
-    tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
-    hass.config.config_dir = tmp_storage
-    store = storage.Store(
-        hass, MOCK_VERSION_2, storage_key, minor_version=MOCK_MINOR_VERSION_1
-    )
-    await store.async_save({"hello": "world"})
-    storage_path = os.path.join(tmp_storage, ".storage")
-    store_file = os.path.join(storage_path, store.key)
-
-    data = await store.async_load()
-    assert data == {"hello": "world"}
-
-    def _corrupt_store():
-        with open(store_file, "w") as f:
-            f.write('{"valid":"json"}..with..corrupt')
-
-    await hass.async_add_executor_job(_corrupt_store)
-
-    data = await store.async_load()
-    assert data is None
-    assert "Unrecoverable error decoding storage" in caplog.text
-
-    issue_registry = ir.async_get(hass)
-    found_issue = None
-    issue_entry = None
-    for (domain, issue), entry in issue_registry.issues.items():
-        if domain == HOMEASSISTANT_DOMAIN and issue.startswith(
-            f"storage_corruption_{storage_key}_"
-        ):
-            found_issue = issue
-            issue_entry = entry
-            break
-
-    assert found_issue is not None
-    assert issue_entry is not None
-    assert issue_entry.is_fixable is True
-    assert issue_entry.translation_placeholders["storage_key"] == storage_key
-    assert issue_entry.issue_domain == "testdomain"
-    assert (
-        issue_entry.translation_placeholders["error"]
-        == "unexpected content after document: line 1 column 17 (char 16)"
-    )
-
-    files = await hass.async_add_executor_job(
-        os.listdir, os.path.join(tmp_storage, ".storage")
-    )
-    assert ".corrupt" in files[0]
-
-    await hass.async_stop(force=True)
+    async with async_test_home_assistant() as hass:
+        hass.config.components.add("testdomain")
+        storage_key = "testdomain.testkey"
+
+        tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
+        hass.config.config_dir = tmp_storage
+        store = storage.Store(
+            hass, MOCK_VERSION_2, storage_key, minor_version=MOCK_MINOR_VERSION_1
+        )
+        await store.async_save({"hello": "world"})
+        storage_path = os.path.join(tmp_storage, ".storage")
+        store_file = os.path.join(storage_path, store.key)
+
+        data = await store.async_load()
+        assert data == {"hello": "world"}
+
+        def _corrupt_store():
+            with open(store_file, "w") as f:
+                f.write('{"valid":"json"}..with..corrupt')
+
+        await hass.async_add_executor_job(_corrupt_store)
+
+        data = await store.async_load()
+        assert data is None
+        assert "Unrecoverable error decoding storage" in caplog.text
+
+        issue_registry = ir.async_get(hass)
+        found_issue = None
+        issue_entry = None
+        for (domain, issue), entry in issue_registry.issues.items():
+            if domain == HOMEASSISTANT_DOMAIN and issue.startswith(
+                f"storage_corruption_{storage_key}_"
+            ):
+                found_issue = issue
+                issue_entry = entry
+                break
+
+        assert found_issue is not None
+        assert issue_entry is not None
+        assert issue_entry.is_fixable is True
+        assert issue_entry.translation_placeholders["storage_key"] == storage_key
+        assert issue_entry.issue_domain == "testdomain"
+        assert (
+            issue_entry.translation_placeholders["error"]
+            == "unexpected content after document: line 1 column 17 (char 16)"
+        )
+
+        files = await hass.async_add_executor_job(
+            os.listdir, os.path.join(tmp_storage, ".storage")
+        )
+        assert ".corrupt" in files[0]
+
+        await hass.async_stop(force=True)

 async def test_os_error_is_fatal(tmpdir: py.path.local) -> None:
     """Test OSError during load is fatal."""
-    loop = asyncio.get_running_loop()
-    hass = await async_test_home_assistant(loop)
-
-    tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
-    hass.config.config_dir = tmp_storage
-
-    store = storage.Store(
-        hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
-    )
-    await store.async_save({"hello": "world"})
-
-    with pytest.raises(OSError), patch(
-        "homeassistant.helpers.storage.json_util.load_json", side_effect=OSError
-    ):
-        await store.async_load()
-
-    base_os_error = OSError()
-    base_os_error.errno = 30
-    home_assistant_error = HomeAssistantError()
-    home_assistant_error.__cause__ = base_os_error
-
-    with pytest.raises(HomeAssistantError), patch(
-        "homeassistant.helpers.storage.json_util.load_json",
-        side_effect=home_assistant_error,
-    ):
-        await store.async_load()
-
-    await hass.async_stop(force=True)
+    async with async_test_home_assistant() as hass:
+        tmp_storage = await hass.async_add_executor_job(tmpdir.mkdir, "temp_storage")
+        hass.config.config_dir = tmp_storage
+
+        store = storage.Store(
+            hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
+        )
+        await store.async_save({"hello": "world"})
+
+        with pytest.raises(OSError), patch(
+            "homeassistant.helpers.storage.json_util.load_json", side_effect=OSError
+        ):
+            await store.async_load()
+
+        base_os_error = OSError()
+        base_os_error.errno = 30
+        home_assistant_error = HomeAssistantError()
+        home_assistant_error.__cause__ = base_os_error
+
+        with pytest.raises(HomeAssistantError), patch(
+            "homeassistant.helpers.storage.json_util.load_json",
+            side_effect=home_assistant_error,
+        ):
+            await store.async_load()
+
+        await hass.async_stop(force=True)

 async def test_read_only_store(
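All four storage tests above get the same mechanical conversion: drop the loop/hass prologue, indent the body under `async with async_test_home_assistant() as hass:`, and keep the explicit `await hass.async_stop(force=True)` at the end. Condensed to its skeleton (the store version and key are placeholders):

    async def test_storage_example(tmpdir) -> None:
        """Sketch of the converted shape shared by the tests above."""
        async with async_test_home_assistant() as hass:
            hass.config.config_dir = await hass.async_add_executor_job(
                tmpdir.mkdir, "temp_storage"
            )
            store = storage.Store(hass, 1, "example.key")  # placeholder version/key
            await store.async_save({"hello": "world"})
            assert await store.async_load() == {"hello": "world"}
            await hass.async_stop(force=True)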


@@ -1,5 +1,4 @@
 """Tests for the storage helper with minimal mocking."""
-import asyncio
 from datetime import timedelta
 import os
 from unittest.mock import patch
@@ -15,24 +14,26 @@ from tests.common import async_fire_time_changed, async_test_home_assistant

 async def test_removing_while_delay_in_progress(tmpdir: py.path.local) -> None:
     """Test removing while delay in progress."""
-    loop = asyncio.get_event_loop()
-    hass = await async_test_home_assistant(loop)
-
-    test_dir = await hass.async_add_executor_job(tmpdir.mkdir, "storage")
-
-    with patch.object(storage, "STORAGE_DIR", test_dir):
-        real_store = storage.Store(hass, 1, "remove_me")
-
-        await real_store.async_save({"delay": "no"})
-
-        assert await hass.async_add_executor_job(os.path.exists, real_store.path)
-
-        real_store.async_delay_save(lambda: {"delay": "yes"}, 1)
-
-        await real_store.async_remove()
-        assert not await hass.async_add_executor_job(os.path.exists, real_store.path)
-
-        async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
-        await hass.async_block_till_done()
-        assert not await hass.async_add_executor_job(os.path.exists, real_store.path)
-
-        await hass.async_stop()
+    async with async_test_home_assistant() as hass:
+        test_dir = await hass.async_add_executor_job(tmpdir.mkdir, "storage")
+
+        with patch.object(storage, "STORAGE_DIR", test_dir):
+            real_store = storage.Store(hass, 1, "remove_me")
+
+            await real_store.async_save({"delay": "no"})
+
+            assert await hass.async_add_executor_job(os.path.exists, real_store.path)
+
+            real_store.async_delay_save(lambda: {"delay": "yes"}, 1)
+
+            await real_store.async_remove()
+            assert not await hass.async_add_executor_job(
+                os.path.exists, real_store.path
+            )
+
+            async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
+            await hass.async_block_till_done()
+            assert not await hass.async_add_executor_job(
+                os.path.exists, real_store.path
+            )
+
+            await hass.async_stop()
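The `import asyncio` removed at the top of this file follows directly from the new calling convention: tests used to look up an event loop only to pass it to the helper, and the loop argument is no longer required. Before and after, condensed:

    # Before: every test threaded a loop through by hand.
    loop = asyncio.get_event_loop()
    hass = await async_test_home_assistant(loop)

    # After: no loop argument; cleanup is tied to the block.
    async with async_test_home_assistant() as hass:
        ...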