Detect if mysql and sqlite support row_number (#57475)

Erik Montnemery 2021-10-12 06:17:18 +02:00 committed by GitHub
parent 3ff30f53a7
commit 0139bfa749
5 changed files with 180 additions and 46 deletions


@ -413,6 +413,7 @@ class Recorder(threading.Thread):
self.async_migration_event = asyncio.Event()
self.migration_in_progress = False
self._queue_watcher = None
self._db_supports_row_number = True
self.enabled = True
@ -972,6 +973,7 @@ class Recorder(threading.Thread):
def setup_recorder_connection(dbapi_connection, connection_record):
"""Dbapi specific connection settings."""
setup_connection_for_dialect(
self,
self.engine.dialect.name,
dbapi_connection,
not self._completed_first_database_setup,

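The hunk above passes the recorder instance into the per-connection setup hook so that setup_connection_for_dialect can record what it learns on self._db_supports_row_number. As a hedged aside, this is roughly how such a hook is wired to a SQLAlchemy engine; the engine URL and listener below are illustrative, not taken from this commit:

from sqlalchemy import create_engine, event

engine = create_engine("sqlite://")  # illustrative engine, not the recorder's


@event.listens_for(engine, "connect")
def _on_connect(dbapi_connection, connection_record):
    # The real callback forwards the instance and dialect name to
    # setup_connection_for_dialect(); a single PRAGMA stands in for that here.
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()


with engine.connect():
    pass  # opening a connection fires the hook once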

@ -89,6 +89,13 @@ QUERY_STATISTICS_SUMMARY_SUM = [
.label("rownum"),
]
QUERY_STATISTICS_SUMMARY_SUM_LEGACY = [
StatisticsShortTerm.metadata_id,
StatisticsShortTerm.last_reset,
StatisticsShortTerm.state,
StatisticsShortTerm.sum,
]
QUERY_STATISTIC_META = [
StatisticsMeta.id,
StatisticsMeta.statistic_id,
@ -275,6 +282,7 @@ def compile_hourly_statistics(
}
# Get last hour's last sum
if instance._db_supports_row_number: # pylint: disable=[protected-access]
subquery = (
session.query(*QUERY_STATISTICS_SUMMARY_SUM)
.filter(StatisticsShortTerm.start >= bindparam("start_time"))
@ -306,6 +314,49 @@ def compile_hourly_statistics(
"state": state,
"sum": _sum,
}
else:
baked_query = instance.hass.data[STATISTICS_SHORT_TERM_BAKERY](
lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
)
baked_query += lambda q: q.filter(
StatisticsShortTerm.start >= bindparam("start_time")
)
baked_query += lambda q: q.filter(
StatisticsShortTerm.start < bindparam("end_time")
)
baked_query += lambda q: q.order_by(
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
)
stats = execute(
baked_query(session).params(start_time=start_time, end_time=end_time)
)
if stats:
for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore
(
metadata_id,
last_reset,
state,
_sum,
) = next(group)
if metadata_id in summary:
summary[metadata_id].update(
{
"start": start_time,
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
)
else:
summary[metadata_id] = {
"start": start_time,
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
# Insert compiled hourly statistics in the database
for metadata_id, stat in summary.items():

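The else branch above is the fallback for databases without window functions: an ordered query plus itertools.groupby recovers the newest short-term row per metadata_id. A hedged sketch of the window-function counterpart that the truncated QUERY_STATISTICS_SUMMARY_SUM definition (the .label("rownum") line) belongs to, using simplified stand-in table and column names rather than the real recorder schema:

import sqlalchemy as sa

metadata_obj = sa.MetaData()
stats = sa.Table(
    "statistics_short_term",  # stand-in table, not the recorder model
    metadata_obj,
    sa.Column("metadata_id", sa.Integer),
    sa.Column("start", sa.DateTime),
    sa.Column("state", sa.Float),
    sa.Column("sum", sa.Float),
)

# Number the rows of each metadata_id partition from newest start to oldest ...
rownum = (
    sa.func.row_number()
    .over(partition_by=stats.c.metadata_id, order_by=stats.c.start.desc())
    .label("rownum")
)
subquery = sa.select(stats.c.metadata_id, stats.c.state, stats.c.sum, rownum).subquery()
# ... and keep only rownum == 1, i.e. the newest row per metadata_id. The legacy
# branch gets the same rows with ORDER BY ... DESC plus groupby/next in Python.
newest_per_id = sa.select(subquery).where(subquery.c.rownum == 1)
print(newest_per_id)  # compiles to SQL; executing it needs a database with window functions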

@ -266,7 +266,18 @@ def execute_on_connection(dbapi_connection, statement):
cursor.close()
def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connection):
def query_on_connection(dbapi_connection, statement):
"""Execute a single statement with a dbapi connection and return the result."""
cursor = dbapi_connection.cursor()
cursor.execute(statement)
result = cursor.fetchall()
cursor.close()
return result
def setup_connection_for_dialect(
instance, dialect_name, dbapi_connection, first_connection
):
"""Execute statements needed for dialect connection."""
# Returns False if the connection needs to be set up
# on the next connection, returns True if the connection
@ -280,6 +291,13 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
# WAL mode only needs to be set up once
# instead of every time we open the sqlite connection
# as it's persistent and isn't free to call every time.
result = query_on_connection(dbapi_connection, "SELECT sqlite_version()")
version = result[0][0]
major, minor, _patch = version.split(".", 2)
if int(major) == 3 and int(minor) < 25:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
# approximately 8MiB of memory
execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192")
@ -289,6 +307,14 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
if dialect_name == "mysql":
execute_on_connection(dbapi_connection, "SET session wait_timeout=28800")
if first_connection:
result = query_on_connection(dbapi_connection, "SELECT VERSION()")
version = result[0][0]
major, minor, _patch = version.split(".", 2)
if int(major) == 5 and int(minor) < 8:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
def end_incomplete_runs(session, start_time):

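A hedged usage sketch of the new query_on_connection() helper together with the version parsing added above, run against a real in-memory SQLite connection from the standard library (the helper body is copied from the hunk; the surrounding script is illustrative):

import sqlite3


def query_on_connection(dbapi_connection, statement):
    """Execute a single statement with a dbapi connection and return the result."""
    cursor = dbapi_connection.cursor()
    cursor.execute(statement)
    result = cursor.fetchall()
    cursor.close()
    return result


conn = sqlite3.connect(":memory:")
result = query_on_connection(conn, "SELECT sqlite_version()")
version = result[0][0]
major, minor, _patch = version.split(".", 2)
# Same check as the sqlite branch above: 3.x releases before 3.25 lack window functions.
supports_row_number = not (int(major) == 3 and int(minor) < 25)
print(f"sqlite {version}: row_number supported = {supports_row_number}")
conn.close()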

@ -122,44 +122,88 @@ async def test_last_run_was_recently_clean(hass):
)
def test_setup_connection_for_dialect_mysql():
@pytest.mark.parametrize(
"mysql_version, db_supports_row_number",
[
("10.0.0", True),
("5.8.0", True),
("5.7.0", False),
],
)
def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
"""Test setting up the connection for a mysql dialect."""
execute_mock = MagicMock()
instance_mock = MagicMock(_db_supports_row_number=True)
execute_args = []
close_mock = MagicMock()
def execute_mock(statement):
nonlocal execute_args
execute_args.append(statement)
def fetchall_mock():
nonlocal execute_args
if execute_args[-1] == "SELECT VERSION()":
return [[mysql_version]]
return None
def _make_cursor_mock(*_):
return MagicMock(execute=execute_mock, close=close_mock)
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
util.setup_connection_for_dialect("mysql", dbapi_connection, True)
util.setup_connection_for_dialect(instance_mock, "mysql", dbapi_connection, True)
assert execute_mock.call_args[0][0] == "SET session wait_timeout=28800"
assert len(execute_args) == 2
assert execute_args[0] == "SET session wait_timeout=28800"
assert execute_args[1] == "SELECT VERSION()"
assert instance_mock._db_supports_row_number == db_supports_row_number
def test_setup_connection_for_dialect_sqlite():
@pytest.mark.parametrize(
"sqlite_version, db_supports_row_number",
[
("3.25.0", True),
("3.24.0", False),
],
)
def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
"""Test setting up the connection for a sqlite dialect."""
execute_mock = MagicMock()
instance_mock = MagicMock(_db_supports_row_number=True)
execute_args = []
close_mock = MagicMock()
def execute_mock(statement):
nonlocal execute_args
execute_args.append(statement)
def fetchall_mock():
nonlocal execute_args
if execute_args[-1] == "SELECT sqlite_version()":
return [[sqlite_version]]
return None
def _make_cursor_mock(*_):
return MagicMock(execute=execute_mock, close=close_mock)
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
util.setup_connection_for_dialect("sqlite", dbapi_connection, True)
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, True)
assert len(execute_mock.call_args_list) == 3
assert execute_mock.call_args_list[0][0][0] == "PRAGMA journal_mode=WAL"
assert execute_mock.call_args_list[1][0][0] == "PRAGMA cache_size = -8192"
assert execute_mock.call_args_list[2][0][0] == "PRAGMA foreign_keys=ON"
assert len(execute_args) == 4
assert execute_args[0] == "PRAGMA journal_mode=WAL"
assert execute_args[1] == "SELECT sqlite_version()"
assert execute_args[2] == "PRAGMA cache_size = -8192"
assert execute_args[3] == "PRAGMA foreign_keys=ON"
execute_mock.reset_mock()
util.setup_connection_for_dialect("sqlite", dbapi_connection, False)
execute_args = []
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, False)
assert len(execute_mock.call_args_list) == 2
assert execute_mock.call_args_list[0][0][0] == "PRAGMA cache_size = -8192"
assert execute_mock.call_args_list[1][0][0] == "PRAGMA foreign_keys=ON"
assert len(execute_args) == 2
assert execute_args[0] == "PRAGMA cache_size = -8192"
assert execute_args[1] == "PRAGMA foreign_keys=ON"
assert instance_mock._db_supports_row_number == db_supports_row_number
def test_basic_sanity_check(hass_recorder):

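The parametrized cases above pin down the capability matrix the recorder now detects. A hedged restatement of the same checks as a standalone pure function (the function name is invented; the comparisons and expected results mirror the diff and the test parameters):

def db_supports_row_number(dialect_name: str, version: str) -> bool:
    """Return False for the dialect/version pairs treated as lacking ROW_NUMBER."""
    major, minor, _patch = version.split(".", 2)
    if dialect_name == "sqlite" and int(major) == 3 and int(minor) < 25:
        return False
    if dialect_name == "mysql" and int(major) == 5 and int(minor) < 8:
        return False
    return True


# Expectations matching the parametrize tables above.
assert db_supports_row_number("mysql", "10.0.0") is True
assert db_supports_row_number("mysql", "5.8.0") is True
assert db_supports_row_number("mysql", "5.7.0") is False
assert db_supports_row_number("sqlite", "3.25.0") is True
assert db_supports_row_number("sqlite", "3.24.0") is False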

@ -1806,7 +1806,13 @@ def test_compile_hourly_statistics_changing_statistics(
assert "Error while processing event StatisticsTask" not in caplog.text
def test_compile_statistics_hourly_summary(hass_recorder, caplog):
@pytest.mark.parametrize(
"db_supports_row_number,in_log,not_in_log",
[(True, "row_number", None), (False, None, "row_number")],
)
def test_compile_statistics_hourly_summary(
hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
):
"""Test compiling hourly statistics."""
zero = dt_util.utcnow()
zero = zero.replace(minute=0, second=0, microsecond=0)
@ -1815,6 +1821,7 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
zero += timedelta(hours=1)
hass = hass_recorder()
recorder = hass.data[DATA_INSTANCE]
recorder._db_supports_row_number = db_supports_row_number
setup_component(hass, "sensor", {})
attributes = {
"device_class": None,
@ -2052,6 +2059,10 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
end += timedelta(hours=1)
assert stats == expected_stats
assert "Error while processing event StatisticsTask" not in caplog.text
if in_log:
assert in_log in caplog.text
if not_in_log:
assert not_in_log not in caplog.text
def record_states(hass, zero, entity_id, attributes, seq=None):
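The in_log / not_in_log assertions above verify which SQL strategy actually ran by inspecting the captured log output. A hedged, self-contained sketch of that caplog pattern; the function under test here is invented purely for illustration:

import logging

import pytest


def compile_summary(use_row_number: bool) -> None:
    """Stand-in for statistics compilation; it only logs which query it would use."""
    if use_row_number:
        logging.getLogger(__name__).debug("using ROW_NUMBER() OVER (...) AS row_number")
    else:
        logging.getLogger(__name__).debug("using ORDER BY metadata_id, start DESC fallback")


@pytest.mark.parametrize(
    "use_row_number, in_log, not_in_log",
    [(True, "row_number", None), (False, None, "row_number")],
)
def test_compile_summary(caplog, use_row_number, in_log, not_in_log):
    caplog.set_level(logging.DEBUG)
    compile_summary(use_row_number)
    if in_log:
        assert in_log in caplog.text
    if not_in_log:
        assert not_in_log not in caplog.text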