* Speed up comparing State and Event objects. Use the default Python implementation for State and Event __hash__ and __eq__. The default implementation compares based on the id() of the object, which is effectively what we want here anyway. These overrides are left over from the days when these used to be attrs objects. By avoiding implementing these ourselves, all of the equality checks can happen in native code. * tweak * adjust tests * write out some more * fix test to not compare objects * more test fixes * more test fixes * correct stats tests * fix more tests * fix more tests * update sensor recorder tests
223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
"""Common test utils for working with recorder."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Iterable
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
import time
|
|
from typing import Any, Literal, cast
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm.session import Session
|
|
|
|
from homeassistant import core as ha
|
|
from homeassistant.components import recorder
|
|
from homeassistant.components.recorder import get_instance, statistics
|
|
from homeassistant.components.recorder.core import Recorder
|
|
from homeassistant.components.recorder.db_schema import RecorderRuns
|
|
from homeassistant.components.recorder.tasks import RecorderTask, StatisticsTask
|
|
from homeassistant.core import Event, HomeAssistant, State
|
|
from homeassistant.util import dt as dt_util
|
|
|
|
from . import db_schema_0
|
|
|
|
# Fallback value async_wait_purge_done uses for ``max`` when the caller passes
# nothing (or any falsy value).
DEFAULT_PURGE_TASKS = 3
|
|
|
|
|
|
@dataclass
class BlockRecorderTask(RecorderTask):
    """Recorder task that stalls whoever runs it; intended for tests only."""

    # Set (via the hass loop) as soon as the task starts running.
    event: asyncio.Event
    # How long ``run`` sleeps after signalling the event.
    seconds: float

    def run(self, instance: Recorder) -> None:
        """Signal that the block has started, then sleep for ``seconds``."""
        hass_loop = instance.hass.loop
        # The event is set through call_soon_threadsafe rather than directly,
        # so the set() call is executed by the hass loop itself.
        hass_loop.call_soon_threadsafe(self.event.set)
        time.sleep(self.seconds)
|
|
|
|
|
|
async def async_block_recorder(hass: HomeAssistant, seconds: float) -> None:
    """Queue a task that blocks the recorder for ``seconds`` (testing only).

    The coroutine completes once the queued task has signalled that it
    started running; it does not wait for the block itself to end.
    """
    block_started = asyncio.Event()
    blocking_task = BlockRecorderTask(block_started, seconds)
    get_instance(hass).queue_task(blocking_task)
    await block_started.wait()
|
|
|
|
|
|
def do_adhoc_statistics(hass: HomeAssistant, **kwargs: Any) -> None:
    """Queue an ad hoc statistics run on the recorder.

    An optional ``start`` keyword argument selects the start time; when it is
    missing or falsy, ``statistics.get_start_time()`` is used instead.
    """
    start_time = kwargs.get("start") or statistics.get_start_time()
    # The second positional argument is passed as False, exactly as before;
    # its meaning is defined by StatisticsTask.
    task = StatisticsTask(start_time, False)
    get_instance(hass).queue_task(task)
|
|
|
|
|
|
def wait_recording_done(hass: HomeAssistant) -> None:
    """Block the calling thread until the recorder has finished its work.

    Sequence: drain hass, force a recorder commit, drain hass again, wait for
    the recorder itself, then drain hass one final time.
    """
    hass.block_till_done()
    trigger_db_commit(hass)
    hass.block_till_done()

    # Look the recorder up only now, after the commit has been triggered.
    recorder_instance = recorder.get_instance(hass)
    recorder_instance.block_till_done()

    hass.block_till_done()
|
|
|
|
|
|
def trigger_db_commit(hass: HomeAssistant) -> None:
    """Make the recorder commit by invoking its commit handler directly."""
    now = dt_util.utcnow()
    recorder_instance = recorder.get_instance(hass)
    # _async_commit is called with the current UTC time as its only argument.
    recorder_instance._async_commit(now)
|
|
|
|
|
|
async def async_wait_recording_done(hass: HomeAssistant) -> None:
    """Wait, without blocking the event loop, until recording is done.

    Sequence: drain hass, force a recorder commit, drain hass again, wait for
    the recorder itself, then drain hass one final time.
    """
    await hass.async_block_till_done()

    # Force a commit, then let hass settle before waiting on the recorder.
    async_trigger_db_commit(hass)
    await hass.async_block_till_done()

    await async_recorder_block_till_done(hass)
    await hass.async_block_till_done()
|
|
|
|
|
|
async def async_wait_purge_done(hass: HomeAssistant, max: int | None = None) -> None:
    """Wait for up to ``max`` purge rounds to complete.

    Because a purge may insert another PurgeTask into the queue after the
    WaitTask finishes, a single wait is not enough; the wait is repeated a
    bounded number of times instead.

    ``max`` is the number of extra waits; when it is ``None`` or otherwise
    falsy (including ``0``), ``DEFAULT_PURGE_TASKS`` is used. The recording
    wait runs ``max + 1`` times in total.
    """
    # Keep the effective count in a local instead of rebinding the parameter,
    # which shadows the ``max`` builtin.
    wait_rounds = max or DEFAULT_PURGE_TASKS
    for _ in range(wait_rounds + 1):
        await async_wait_recording_done(hass)
|
|
|
|
|
|
@ha.callback
def async_trigger_db_commit(hass: HomeAssistant) -> None:
    """Make the recorder commit; variant registered as a hass callback."""
    now = dt_util.utcnow()
    recorder_instance = recorder.get_instance(hass)
    # _async_commit is called with the current UTC time as its only argument.
    recorder_instance._async_commit(now)
|
|
|
|
|
|
async def async_recorder_block_till_done(hass: HomeAssistant) -> None:
    """Await the recorder's block_till_done() by running it in the executor."""
    wait_for_recorder = recorder.get_instance(hass).block_till_done
    await hass.async_add_executor_job(wait_for_recorder)
|
|
|
|
|
|
def corrupt_db_file(test_db_file) -> None:
    """Corrupt an sqlite3 database file.

    The file at ``test_db_file`` is opened in ``"w+"`` mode, which truncates
    it, and junk text is then written starting at offset 200. What remains is
    no longer a usable database.
    """
    junk = "I am a corrupt db" * 100
    # Pin the encoding so the bytes written do not depend on the locale's
    # default text encoding.
    with open(test_db_file, "w+", encoding="utf-8") as fhandle:
        fhandle.seek(200)
        fhandle.write(junk)
|
|
|
|
|
|
def create_engine_test(*args, **kwargs):
    """Stand-in for create_engine that pre-populates the old schema.

    All arguments are forwarded to ``create_engine``; the tables declared by
    ``db_schema_0`` are then created on the new engine, simulating an existing
    database that still uses the old schema.
    """
    old_schema_engine = create_engine(*args, **kwargs)
    schema_metadata = db_schema_0.Base.metadata
    schema_metadata.create_all(old_schema_engine)
    return old_schema_engine
|
|
|
|
|
|
def run_information_with_session(
    session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
    """Fetch a RecorderRuns row using ``session``.

    With ``point_in_time`` given, only runs whose ``start`` is before and
    whose ``end`` is after that moment are considered. The first matching row
    is expunged from the session before being returned; ``None`` is returned
    when the query yields nothing.
    """
    query = session.query(RecorderRuns)

    if point_in_time:
        started_before = RecorderRuns.start < point_in_time
        ended_after = RecorderRuns.end > point_in_time
        query = query.filter(started_before & ended_after)

    run = query.first()
    if run is None:
        return None

    # Remove the row from the session before handing it back.
    session.expunge(run)
    return cast(RecorderRuns, run)
|
|
|
|
|
|
def statistics_during_period(
    hass: HomeAssistant,
    start_time: datetime,
    end_time: datetime | None = None,
    statistic_ids: list[str] | None = None,
    period: Literal["5minute", "day", "hour", "week", "month"] = "hour",
    units: dict[str, str] | None = None,
    types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]]
    | None = None,
) -> dict[str, list[dict[str, Any]]]:
    """Wrap statistics.statistics_during_period with test-friendly defaults.

    Every argument is forwarded positionally. The only substitution is that
    ``types=None`` is replaced by the full set of statistic types.
    """
    requested_types = (
        types
        if types is not None
        else {"last_reset", "max", "mean", "min", "state", "sum"}
    )
    return statistics.statistics_during_period(
        hass, start_time, end_time, statistic_ids, period, units, requested_types
    )
|
|
|
|
|
|
def assert_states_equal_without_context(state: State, other: State) -> None:
    """Assert two states match on everything checked here except context.

    Checks state, attributes and last_updated through the shared helper, then
    additionally requires last_changed to match.
    """
    # Shared field checks first, so a mismatch there is reported first.
    assert_states_equal_without_context_and_last_changed(state, other)

    assert state.last_changed == other.last_changed
|
|
|
|
|
|
def assert_states_equal_without_context_and_last_changed(
    state: State, other: State
) -> None:
    """Assert two states match on state, attributes and last_updated.

    Context and last_changed are deliberately not compared.
    """
    # Compared in this order; the first mismatch raises AssertionError.
    for field in ("state", "attributes", "last_updated"):
        assert getattr(state, field) == getattr(other, field)
|
|
|
|
|
|
def assert_multiple_states_equal_without_context_and_last_changed(
    states: Iterable[State], others: Iterable[State]
) -> None:
    """Assert two state sequences match pairwise, ignoring context and last_changed.

    Both iterables are materialised first so their lengths can be compared
    before any element is checked.
    """
    actual = list(states)
    expected = list(others)
    assert len(actual) == len(expected)
    # Lengths are equal at this point, so zip pairs every element.
    for actual_state, expected_state in zip(actual, expected):
        assert_states_equal_without_context_and_last_changed(
            actual_state, expected_state
        )
|
|
|
|
|
|
def assert_multiple_states_equal_without_context(
    states: Iterable[State], others: Iterable[State]
) -> None:
    """Assert two state sequences match pairwise, ignoring context.

    Both iterables are materialised first so their lengths can be compared
    before any element is checked.
    """
    actual = list(states)
    expected = list(others)
    assert len(actual) == len(expected)
    # Lengths are equal at this point, so zip pairs every element.
    for actual_state, expected_state in zip(actual, expected):
        assert_states_equal_without_context(actual_state, expected_state)
|
|
|
|
|
|
def assert_events_equal_without_context(event: Event, other: Event) -> None:
    """Assert two events match on data, event_type, origin and time_fired.

    Context is deliberately not compared.
    """
    # Compared in this order; the first mismatch raises AssertionError.
    for field in ("data", "event_type", "origin", "time_fired"):
        assert getattr(event, field) == getattr(other, field)
|
|
|
|
|
|
def assert_dict_of_states_equal_without_context(
    states: dict[str, list[State]], others: dict[str, list[State]]
) -> None:
    """Assert two entity_id -> states mappings match, ignoring context.

    The mappings must have the same number of entries, and each list in
    ``states`` is compared against the list stored under the same key in
    ``others``.
    """
    assert len(states) == len(others)
    for entity_id, entity_states in states.items():
        expected_states = others[entity_id]
        assert_multiple_states_equal_without_context(entity_states, expected_states)
|
|
|
|
|
|
def assert_dict_of_states_equal_without_context_and_last_changed(
    states: dict[str, list[State]], others: dict[str, list[State]]
) -> None:
    """Assert two entity_id -> states mappings match, ignoring context and last_changed.

    The mappings must have the same number of entries, and each list in
    ``states`` is compared against the list stored under the same key in
    ``others``.
    """
    assert len(states) == len(others)
    for entity_id, entity_states in states.items():
        expected_states = others[entity_id]
        assert_multiple_states_equal_without_context_and_last_changed(
            entity_states, expected_states
        )
|