hass-core/tests/components/recorder/conftest.py
Erik Montnemery f1d6ad9073
Add test fixture to control recorder migration (#121180)
* Add test fixture to control recorder migration

* Update tests/components/recorder/conftest.py

Co-authored-by: J. Nick Koston <nick@koston.org>

* Update tests/components/recorder/conftest.py

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
2024-07-04 13:10:08 +02:00

119 lines
4 KiB
Python

"""Fixtures for the recorder component tests."""
from dataclasses import dataclass
import threading
from unittest.mock import Mock, patch
import pytest
from typing_extensions import AsyncGenerator, Generator
from homeassistant.components import recorder
from homeassistant.components.recorder import db_schema
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
def pytest_configure(config) -> None:
    """Register the custom skip_on_db_engine marker.

    The marker is consumed by the ``skip_by_db_engine`` fixture, which skips
    the marked test when the recorder DB URL's engine is one of the engines
    named in the marker.
    """
    config.addinivalue_line(
        "markers",
        # Describe what the marker actually does: skip on the named engines.
        "skip_on_db_engine(engine): mark test to be skipped on named DB engine(s)",
    )
@pytest.fixture
def skip_by_db_engine(request: pytest.FixtureRequest, recorder_db_url: str) -> None:
    """Fixture to skip tests on unsupported DB engines.

    Mark the test with @pytest.mark.skip_on_db_engine("mysql") to skip on mysql, or
    @pytest.mark.skip_on_db_engine(["mysql", "sqlite"]) to skip on mysql and sqlite.
    """
    # Look the marker up once; unmarked tests are never skipped.
    marker = request.node.get_closest_marker("skip_on_db_engine")
    if marker is None:
        return

    skip_on_db_engine = marker.args[0]
    # A single engine may be given as a plain string; normalize to a list so
    # the membership test below does not do substring matching.
    if isinstance(skip_on_db_engine, str):
        skip_on_db_engine = [skip_on_db_engine]

    # The engine name is the URL scheme, i.e. everything before "://".
    db_engine = recorder_db_url.partition("://")[0]
    if db_engine in skip_on_db_engine:
        pytest.skip(f"skipped for DB engine: {db_engine}")
@pytest.fixture
def recorder_dialect_name(hass: HomeAssistant, db_engine: str) -> Generator[None]:
    """Make the recorder report ``db_engine`` as its dialect name during a test.

    When a recorder instance is already stored in ``hass.data``, its private
    ``_dialect_name`` attribute is patched, and any ``dialect_name`` entry in
    the instance ``__dict__`` is dropped before and after the test (presumably
    a cached value — confirm against ``Recorder``). Without an instance, the
    ``dialect_name`` attribute on the ``Recorder`` class is replaced instead.
    """
    instance = hass.data.get(recorder.DATA_INSTANCE)

    if not instance:
        # No recorder yet: replace the class-level attribute for the test.
        with patch(
            "homeassistant.components.recorder.Recorder.dialect_name", db_engine
        ):
            yield
        return

    # Discard any per-instance value so the patched attribute is observed.
    instance.__dict__.pop("dialect_name", None)
    with patch.object(instance, "_dialect_name", db_engine):
        yield
        # Discard anything computed from the patched value during the test.
        instance.__dict__.pop("dialect_name", None)
@dataclass(slots=True)
class InstrumentedMigration:
    """Container to aid controlling migration progress.

    Shared between a test and the patched migration functions installed by
    the ``instrument_migration`` fixture.
    """

    # Set by the instrumented migrate_schema when the migration finishes,
    # whether it succeeds or raises.
    migration_done: threading.Event
    # Each instrumented _apply_update call waits on this event before running;
    # the test sets it to let the migration proceed.
    migration_stall: threading.Event
    # Set by the instrumented migrate_schema as soon as it is entered.
    migration_started: threading.Event
    # schema_version of the newest SchemaChanges row, recorded after a
    # successful migration; None until then.
    migration_version: int | None
    # The mock patched over migration._apply_update (it wraps the real call).
    apply_update_mock: Mock
@pytest.fixture
async def instrument_migration(
    hass: HomeAssistant,
) -> AsyncGenerator[InstrumentedMigration]:
    """Instrument recorder migration.

    For the duration of the test, ``migration.migrate_schema`` and
    ``migration._apply_update`` are patched with wrappers that let the test
    observe and pace the migration through the yielded container:

    - ``migration_started`` is set when ``migrate_schema`` is entered.
    - Every ``_apply_update`` call blocks until ``migration_stall`` is set.
    - ``migration_done`` is set when ``migrate_schema`` finishes, whether it
      succeeds or raises.
    - ``migration_version`` holds the newest ``SchemaChanges.schema_version``
      found after a successful migration (``None`` if none was found).
    """
    real_migrate_schema = recorder.migration.migrate_schema
    real_apply_update = recorder.migration._apply_update

    def _instrument_migrate_schema(*args, **kwargs):
        """Control migration progress and check results."""
        instrumented_migration.migration_started.set()
        try:
            # Forward all arguments and the result so the wrapper is transparent.
            result = real_migrate_schema(*args, **kwargs)
            # Check and report the outcome of the migration; if migration fails
            # the recorder will silently create a new database.
            with session_scope(hass=hass, read_only=True) as session:
                res = (
                    session.query(db_schema.SchemaChanges)
                    .order_by(db_schema.SchemaChanges.change_id.desc())
                    .first()
                )
            if res is not None:
                instrumented_migration.migration_version = res.schema_version
        finally:
            # Signal completion on every path, including a failure of the
            # outcome check above, so a test waiting on migration_done
            # cannot hang. migration_version is assigned before this point.
            instrumented_migration.migration_done.set()
        return result

    def _instrument_apply_update(*args, **kwargs):
        """Control migration progress."""
        # Hold each schema update until the test sets migration_stall.
        instrumented_migration.migration_stall.wait()
        return real_apply_update(*args, **kwargs)

    with (
        patch(
            "homeassistant.components.recorder.migration.migrate_schema",
            wraps=_instrument_migrate_schema,
        ),
        patch(
            "homeassistant.components.recorder.migration._apply_update",
            wraps=_instrument_apply_update,
        ) as apply_update_mock,
    ):
        instrumented_migration = InstrumentedMigration(
            migration_done=threading.Event(),
            migration_stall=threading.Event(),
            migration_started=threading.Event(),
            migration_version=None,
            apply_update_mock=apply_update_mock,
        )
        yield instrumented_migration