Detect if mysql and sqlite support row_number (#57475)
This commit is contained in:
parent
3ff30f53a7
commit
0139bfa749
5 changed files with 180 additions and 46 deletions
|
@ -413,6 +413,7 @@ class Recorder(threading.Thread):
|
|||
self.async_migration_event = asyncio.Event()
|
||||
self.migration_in_progress = False
|
||||
self._queue_watcher = None
|
||||
self._db_supports_row_number = True
|
||||
|
||||
self.enabled = True
|
||||
|
||||
|
@ -972,6 +973,7 @@ class Recorder(threading.Thread):
|
|||
def setup_recorder_connection(dbapi_connection, connection_record):
|
||||
"""Dbapi specific connection settings."""
|
||||
setup_connection_for_dialect(
|
||||
self,
|
||||
self.engine.dialect.name,
|
||||
dbapi_connection,
|
||||
not self._completed_first_database_setup,
|
||||
|
|
|
@ -89,6 +89,13 @@ QUERY_STATISTICS_SUMMARY_SUM = [
|
|||
.label("rownum"),
|
||||
]
|
||||
|
||||
QUERY_STATISTICS_SUMMARY_SUM_LEGACY = [
|
||||
StatisticsShortTerm.metadata_id,
|
||||
StatisticsShortTerm.last_reset,
|
||||
StatisticsShortTerm.state,
|
||||
StatisticsShortTerm.sum,
|
||||
]
|
||||
|
||||
QUERY_STATISTIC_META = [
|
||||
StatisticsMeta.id,
|
||||
StatisticsMeta.statistic_id,
|
||||
|
@ -275,37 +282,81 @@ def compile_hourly_statistics(
|
|||
}
|
||||
|
||||
# Get last hour's last sum
|
||||
subquery = (
|
||||
session.query(*QUERY_STATISTICS_SUMMARY_SUM)
|
||||
.filter(StatisticsShortTerm.start >= bindparam("start_time"))
|
||||
.filter(StatisticsShortTerm.start < bindparam("end_time"))
|
||||
.subquery()
|
||||
)
|
||||
query = (
|
||||
session.query(subquery)
|
||||
.filter(subquery.c.rownum == 1)
|
||||
.order_by(subquery.c.metadata_id)
|
||||
)
|
||||
stats = execute(query.params(start_time=start_time, end_time=end_time))
|
||||
if instance._db_supports_row_number: # pylint: disable=[protected-access]
|
||||
subquery = (
|
||||
session.query(*QUERY_STATISTICS_SUMMARY_SUM)
|
||||
.filter(StatisticsShortTerm.start >= bindparam("start_time"))
|
||||
.filter(StatisticsShortTerm.start < bindparam("end_time"))
|
||||
.subquery()
|
||||
)
|
||||
query = (
|
||||
session.query(subquery)
|
||||
.filter(subquery.c.rownum == 1)
|
||||
.order_by(subquery.c.metadata_id)
|
||||
)
|
||||
stats = execute(query.params(start_time=start_time, end_time=end_time))
|
||||
|
||||
if stats:
|
||||
for stat in stats:
|
||||
metadata_id, start, last_reset, state, _sum, _ = stat
|
||||
if metadata_id in summary:
|
||||
summary[metadata_id].update(
|
||||
{
|
||||
if stats:
|
||||
for stat in stats:
|
||||
metadata_id, start, last_reset, state, _sum, _ = stat
|
||||
if metadata_id in summary:
|
||||
summary[metadata_id].update(
|
||||
{
|
||||
"last_reset": process_timestamp(last_reset),
|
||||
"state": state,
|
||||
"sum": _sum,
|
||||
}
|
||||
)
|
||||
else:
|
||||
summary[metadata_id] = {
|
||||
"start": start_time,
|
||||
"last_reset": process_timestamp(last_reset),
|
||||
"state": state,
|
||||
"sum": _sum,
|
||||
}
|
||||
else:
|
||||
baked_query = instance.hass.data[STATISTICS_SHORT_TERM_BAKERY](
|
||||
lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
|
||||
)
|
||||
|
||||
baked_query += lambda q: q.filter(
|
||||
StatisticsShortTerm.start >= bindparam("start_time")
|
||||
)
|
||||
baked_query += lambda q: q.filter(
|
||||
StatisticsShortTerm.start < bindparam("end_time")
|
||||
)
|
||||
baked_query += lambda q: q.order_by(
|
||||
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
|
||||
)
|
||||
|
||||
stats = execute(
|
||||
baked_query(session).params(start_time=start_time, end_time=end_time)
|
||||
)
|
||||
|
||||
if stats:
|
||||
for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore
|
||||
(
|
||||
metadata_id,
|
||||
last_reset,
|
||||
state,
|
||||
_sum,
|
||||
) = next(group)
|
||||
if metadata_id in summary:
|
||||
summary[metadata_id].update(
|
||||
{
|
||||
"start": start_time,
|
||||
"last_reset": process_timestamp(last_reset),
|
||||
"state": state,
|
||||
"sum": _sum,
|
||||
}
|
||||
)
|
||||
else:
|
||||
summary[metadata_id] = {
|
||||
"start": start_time,
|
||||
"last_reset": process_timestamp(last_reset),
|
||||
"state": state,
|
||||
"sum": _sum,
|
||||
}
|
||||
)
|
||||
else:
|
||||
summary[metadata_id] = {
|
||||
"start": start_time,
|
||||
"last_reset": process_timestamp(last_reset),
|
||||
"state": state,
|
||||
"sum": _sum,
|
||||
}
|
||||
|
||||
# Insert compiled hourly statistics in the database
|
||||
for metadata_id, stat in summary.items():
|
||||
|
|
|
@ -266,7 +266,18 @@ def execute_on_connection(dbapi_connection, statement):
|
|||
cursor.close()
|
||||
|
||||
|
||||
def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connection):
|
||||
def query_on_connection(dbapi_connection, statement):
|
||||
"""Execute a single statement with a dbapi connection and return the result."""
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute(statement)
|
||||
result = cursor.fetchall()
|
||||
cursor.close()
|
||||
return result
|
||||
|
||||
|
||||
def setup_connection_for_dialect(
|
||||
instance, dialect_name, dbapi_connection, first_connection
|
||||
):
|
||||
"""Execute statements needed for dialect connection."""
|
||||
# Returns False if the the connection needs to be setup
|
||||
# on the next connection, returns True if the connection
|
||||
|
@ -280,6 +291,13 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
|
|||
# WAL mode only needs to be setup once
|
||||
# instead of every time we open the sqlite connection
|
||||
# as its persistent and isn't free to call every time.
|
||||
result = query_on_connection(dbapi_connection, "SELECT sqlite_version()")
|
||||
version = result[0][0]
|
||||
major, minor, _patch = version.split(".", 2)
|
||||
if int(major) == 3 and int(minor) < 25:
|
||||
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
|
||||
False
|
||||
)
|
||||
|
||||
# approximately 8MiB of memory
|
||||
execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192")
|
||||
|
@ -289,6 +307,14 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
|
|||
|
||||
if dialect_name == "mysql":
|
||||
execute_on_connection(dbapi_connection, "SET session wait_timeout=28800")
|
||||
if first_connection:
|
||||
result = query_on_connection(dbapi_connection, "SELECT VERSION()")
|
||||
version = result[0][0]
|
||||
major, minor, _patch = version.split(".", 2)
|
||||
if int(major) == 5 and int(minor) < 8:
|
||||
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
|
||||
False
|
||||
)
|
||||
|
||||
|
||||
def end_incomplete_runs(session, start_time):
|
||||
|
|
|
@ -122,44 +122,88 @@ async def test_last_run_was_recently_clean(hass):
|
|||
)
|
||||
|
||||
|
||||
def test_setup_connection_for_dialect_mysql():
|
||||
@pytest.mark.parametrize(
|
||||
"mysql_version, db_supports_row_number",
|
||||
[
|
||||
("10.0.0", True),
|
||||
("5.8.0", True),
|
||||
("5.7.0", False),
|
||||
],
|
||||
)
|
||||
def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
|
||||
"""Test setting up the connection for a mysql dialect."""
|
||||
execute_mock = MagicMock()
|
||||
instance_mock = MagicMock(_db_supports_row_number=True)
|
||||
execute_args = []
|
||||
close_mock = MagicMock()
|
||||
|
||||
def execute_mock(statement):
|
||||
nonlocal execute_args
|
||||
execute_args.append(statement)
|
||||
|
||||
def fetchall_mock():
|
||||
nonlocal execute_args
|
||||
if execute_args[-1] == "SELECT VERSION()":
|
||||
return [[mysql_version]]
|
||||
return None
|
||||
|
||||
def _make_cursor_mock(*_):
|
||||
return MagicMock(execute=execute_mock, close=close_mock)
|
||||
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
|
||||
|
||||
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
||||
|
||||
util.setup_connection_for_dialect("mysql", dbapi_connection, True)
|
||||
util.setup_connection_for_dialect(instance_mock, "mysql", dbapi_connection, True)
|
||||
|
||||
assert execute_mock.call_args[0][0] == "SET session wait_timeout=28800"
|
||||
assert len(execute_args) == 2
|
||||
assert execute_args[0] == "SET session wait_timeout=28800"
|
||||
assert execute_args[1] == "SELECT VERSION()"
|
||||
|
||||
assert instance_mock._db_supports_row_number == db_supports_row_number
|
||||
|
||||
|
||||
def test_setup_connection_for_dialect_sqlite():
|
||||
@pytest.mark.parametrize(
|
||||
"sqlite_version, db_supports_row_number",
|
||||
[
|
||||
("3.25.0", True),
|
||||
("3.24.0", False),
|
||||
],
|
||||
)
|
||||
def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
|
||||
"""Test setting up the connection for a sqlite dialect."""
|
||||
execute_mock = MagicMock()
|
||||
instance_mock = MagicMock(_db_supports_row_number=True)
|
||||
execute_args = []
|
||||
close_mock = MagicMock()
|
||||
|
||||
def execute_mock(statement):
|
||||
nonlocal execute_args
|
||||
execute_args.append(statement)
|
||||
|
||||
def fetchall_mock():
|
||||
nonlocal execute_args
|
||||
if execute_args[-1] == "SELECT sqlite_version()":
|
||||
return [[sqlite_version]]
|
||||
return None
|
||||
|
||||
def _make_cursor_mock(*_):
|
||||
return MagicMock(execute=execute_mock, close=close_mock)
|
||||
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
|
||||
|
||||
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
||||
|
||||
util.setup_connection_for_dialect("sqlite", dbapi_connection, True)
|
||||
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, True)
|
||||
|
||||
assert len(execute_mock.call_args_list) == 3
|
||||
assert execute_mock.call_args_list[0][0][0] == "PRAGMA journal_mode=WAL"
|
||||
assert execute_mock.call_args_list[1][0][0] == "PRAGMA cache_size = -8192"
|
||||
assert execute_mock.call_args_list[2][0][0] == "PRAGMA foreign_keys=ON"
|
||||
assert len(execute_args) == 4
|
||||
assert execute_args[0] == "PRAGMA journal_mode=WAL"
|
||||
assert execute_args[1] == "SELECT sqlite_version()"
|
||||
assert execute_args[2] == "PRAGMA cache_size = -8192"
|
||||
assert execute_args[3] == "PRAGMA foreign_keys=ON"
|
||||
|
||||
execute_mock.reset_mock()
|
||||
util.setup_connection_for_dialect("sqlite", dbapi_connection, False)
|
||||
execute_args = []
|
||||
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, False)
|
||||
|
||||
assert len(execute_mock.call_args_list) == 2
|
||||
assert execute_mock.call_args_list[0][0][0] == "PRAGMA cache_size = -8192"
|
||||
assert execute_mock.call_args_list[1][0][0] == "PRAGMA foreign_keys=ON"
|
||||
assert len(execute_args) == 2
|
||||
assert execute_args[0] == "PRAGMA cache_size = -8192"
|
||||
assert execute_args[1] == "PRAGMA foreign_keys=ON"
|
||||
|
||||
assert instance_mock._db_supports_row_number == db_supports_row_number
|
||||
|
||||
|
||||
def test_basic_sanity_check(hass_recorder):
|
||||
|
|
|
@ -1806,7 +1806,13 @@ def test_compile_hourly_statistics_changing_statistics(
|
|||
assert "Error while processing event StatisticsTask" not in caplog.text
|
||||
|
||||
|
||||
def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
||||
@pytest.mark.parametrize(
|
||||
"db_supports_row_number,in_log,not_in_log",
|
||||
[(True, "row_number", None), (False, None, "row_number")],
|
||||
)
|
||||
def test_compile_statistics_hourly_summary(
|
||||
hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
|
||||
):
|
||||
"""Test compiling hourly statistics."""
|
||||
zero = dt_util.utcnow()
|
||||
zero = zero.replace(minute=0, second=0, microsecond=0)
|
||||
|
@ -1815,6 +1821,7 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
|||
zero += timedelta(hours=1)
|
||||
hass = hass_recorder()
|
||||
recorder = hass.data[DATA_INSTANCE]
|
||||
recorder._db_supports_row_number = db_supports_row_number
|
||||
setup_component(hass, "sensor", {})
|
||||
attributes = {
|
||||
"device_class": None,
|
||||
|
@ -2052,6 +2059,10 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
|||
end += timedelta(hours=1)
|
||||
assert stats == expected_stats
|
||||
assert "Error while processing event StatisticsTask" not in caplog.text
|
||||
if in_log:
|
||||
assert in_log in caplog.text
|
||||
if not_in_log:
|
||||
assert not_in_log not in caplog.text
|
||||
|
||||
|
||||
def record_states(hass, zero, entity_id, attributes, seq=None):
|
||||
|
|
Loading…
Add table
Reference in a new issue