From 0139bfa7497cfed160eeaf6016c50cb99e953361 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 12 Oct 2021 06:17:18 +0200 Subject: [PATCH] Detect if mysql and sqlite support row_number (#57475) --- homeassistant/components/recorder/__init__.py | 2 + .../components/recorder/statistics.py | 103 +++++++++++++----- homeassistant/components/recorder/util.py | 28 ++++- tests/components/recorder/test_util.py | 80 +++++++++++--- tests/components/sensor/test_recorder.py | 13 ++- 5 files changed, 180 insertions(+), 46 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 1b090c331a7..7e9bab0ed4e 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -413,6 +413,7 @@ class Recorder(threading.Thread): self.async_migration_event = asyncio.Event() self.migration_in_progress = False self._queue_watcher = None + self._db_supports_row_number = True self.enabled = True @@ -972,6 +973,7 @@ class Recorder(threading.Thread): def setup_recorder_connection(dbapi_connection, connection_record): """Dbapi specific connection settings.""" setup_connection_for_dialect( + self, self.engine.dialect.name, dbapi_connection, not self._completed_first_database_setup, diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index d253d1e2275..200da8d192d 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -89,6 +89,13 @@ QUERY_STATISTICS_SUMMARY_SUM = [ .label("rownum"), ] +QUERY_STATISTICS_SUMMARY_SUM_LEGACY = [ + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.last_reset, + StatisticsShortTerm.state, + StatisticsShortTerm.sum, +] + QUERY_STATISTIC_META = [ StatisticsMeta.id, StatisticsMeta.statistic_id, @@ -275,37 +282,81 @@ def compile_hourly_statistics( } # Get last hour's last sum - subquery = ( - session.query(*QUERY_STATISTICS_SUMMARY_SUM) - .filter(StatisticsShortTerm.start >= bindparam("start_time")) - .filter(StatisticsShortTerm.start < bindparam("end_time")) - .subquery() - ) - query = ( - session.query(subquery) - .filter(subquery.c.rownum == 1) - .order_by(subquery.c.metadata_id) - ) - stats = execute(query.params(start_time=start_time, end_time=end_time)) + if instance._db_supports_row_number: # pylint: disable=[protected-access] + subquery = ( + session.query(*QUERY_STATISTICS_SUMMARY_SUM) + .filter(StatisticsShortTerm.start >= bindparam("start_time")) + .filter(StatisticsShortTerm.start < bindparam("end_time")) + .subquery() + ) + query = ( + session.query(subquery) + .filter(subquery.c.rownum == 1) + .order_by(subquery.c.metadata_id) + ) + stats = execute(query.params(start_time=start_time, end_time=end_time)) - if stats: - for stat in stats: - metadata_id, start, last_reset, state, _sum, _ = stat - if metadata_id in summary: - summary[metadata_id].update( - { + if stats: + for stat in stats: + metadata_id, start, last_reset, state, _sum, _ = stat + if metadata_id in summary: + summary[metadata_id].update( + { + "last_reset": process_timestamp(last_reset), + "state": state, + "sum": _sum, + } + ) + else: + summary[metadata_id] = { + "start": start_time, + "last_reset": process_timestamp(last_reset), + "state": state, + "sum": _sum, + } + else: + baked_query = instance.hass.data[STATISTICS_SHORT_TERM_BAKERY]( + lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY) + ) + + baked_query += lambda q: q.filter( + StatisticsShortTerm.start >= bindparam("start_time") + ) + baked_query += lambda q: q.filter( + StatisticsShortTerm.start < bindparam("end_time") + ) + baked_query += lambda q: q.order_by( + StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc() + ) + + stats = execute( + baked_query(session).params(start_time=start_time, end_time=end_time) + ) + + if stats: + for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore + ( + metadata_id, + last_reset, + state, + _sum, + ) = next(group) + if metadata_id in summary: + summary[metadata_id].update( + { + "start": start_time, + "last_reset": process_timestamp(last_reset), + "state": state, + "sum": _sum, + } + ) + else: + summary[metadata_id] = { + "start": start_time, "last_reset": process_timestamp(last_reset), "state": state, "sum": _sum, } - ) - else: - summary[metadata_id] = { - "start": start_time, - "last_reset": process_timestamp(last_reset), - "state": state, - "sum": _sum, - } # Insert compiled hourly statistics in the database for metadata_id, stat in summary.items(): diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 101915c7117..567164d4325 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -266,7 +266,18 @@ def execute_on_connection(dbapi_connection, statement): cursor.close() -def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connection): +def query_on_connection(dbapi_connection, statement): + """Execute a single statement with a dbapi connection and return the result.""" + cursor = dbapi_connection.cursor() + cursor.execute(statement) + result = cursor.fetchall() + cursor.close() + return result + + +def setup_connection_for_dialect( + instance, dialect_name, dbapi_connection, first_connection +): """Execute statements needed for dialect connection.""" # Returns False if the the connection needs to be setup # on the next connection, returns True if the connection @@ -280,6 +291,13 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio # WAL mode only needs to be setup once # instead of every time we open the sqlite connection # as its persistent and isn't free to call every time. + result = query_on_connection(dbapi_connection, "SELECT sqlite_version()") + version = result[0][0] + major, minor, _patch = version.split(".", 2) + if int(major) == 3 and int(minor) < 25: + instance._db_supports_row_number = ( # pylint: disable=[protected-access] + False + ) # approximately 8MiB of memory execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192") @@ -289,6 +307,14 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio if dialect_name == "mysql": execute_on_connection(dbapi_connection, "SET session wait_timeout=28800") + if first_connection: + result = query_on_connection(dbapi_connection, "SELECT VERSION()") + version = result[0][0] + major, minor, _patch = version.split(".", 2) + if int(major) == 5 and int(minor) < 8: + instance._db_supports_row_number = ( # pylint: disable=[protected-access] + False + ) def end_incomplete_runs(session, start_time): diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index f193993ffe5..8b5de5cff16 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -122,44 +122,88 @@ async def test_last_run_was_recently_clean(hass): ) -def test_setup_connection_for_dialect_mysql(): +@pytest.mark.parametrize( + "mysql_version, db_supports_row_number", + [ + ("10.0.0", True), + ("5.8.0", True), + ("5.7.0", False), + ], +) +def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number): """Test setting up the connection for a mysql dialect.""" - execute_mock = MagicMock() + instance_mock = MagicMock(_db_supports_row_number=True) + execute_args = [] close_mock = MagicMock() + def execute_mock(statement): + nonlocal execute_args + execute_args.append(statement) + + def fetchall_mock(): + nonlocal execute_args + if execute_args[-1] == "SELECT VERSION()": + return [[mysql_version]] + return None + def _make_cursor_mock(*_): - return MagicMock(execute=execute_mock, close=close_mock) + return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock) dbapi_connection = MagicMock(cursor=_make_cursor_mock) - util.setup_connection_for_dialect("mysql", dbapi_connection, True) + util.setup_connection_for_dialect(instance_mock, "mysql", dbapi_connection, True) - assert execute_mock.call_args[0][0] == "SET session wait_timeout=28800" + assert len(execute_args) == 2 + assert execute_args[0] == "SET session wait_timeout=28800" + assert execute_args[1] == "SELECT VERSION()" + + assert instance_mock._db_supports_row_number == db_supports_row_number -def test_setup_connection_for_dialect_sqlite(): +@pytest.mark.parametrize( + "sqlite_version, db_supports_row_number", + [ + ("3.25.0", True), + ("3.24.0", False), + ], +) +def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number): """Test setting up the connection for a sqlite dialect.""" - execute_mock = MagicMock() + instance_mock = MagicMock(_db_supports_row_number=True) + execute_args = [] close_mock = MagicMock() + def execute_mock(statement): + nonlocal execute_args + execute_args.append(statement) + + def fetchall_mock(): + nonlocal execute_args + if execute_args[-1] == "SELECT sqlite_version()": + return [[sqlite_version]] + return None + def _make_cursor_mock(*_): - return MagicMock(execute=execute_mock, close=close_mock) + return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock) dbapi_connection = MagicMock(cursor=_make_cursor_mock) - util.setup_connection_for_dialect("sqlite", dbapi_connection, True) + util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, True) - assert len(execute_mock.call_args_list) == 3 - assert execute_mock.call_args_list[0][0][0] == "PRAGMA journal_mode=WAL" - assert execute_mock.call_args_list[1][0][0] == "PRAGMA cache_size = -8192" - assert execute_mock.call_args_list[2][0][0] == "PRAGMA foreign_keys=ON" + assert len(execute_args) == 4 + assert execute_args[0] == "PRAGMA journal_mode=WAL" + assert execute_args[1] == "SELECT sqlite_version()" + assert execute_args[2] == "PRAGMA cache_size = -8192" + assert execute_args[3] == "PRAGMA foreign_keys=ON" - execute_mock.reset_mock() - util.setup_connection_for_dialect("sqlite", dbapi_connection, False) + execute_args = [] + util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, False) - assert len(execute_mock.call_args_list) == 2 - assert execute_mock.call_args_list[0][0][0] == "PRAGMA cache_size = -8192" - assert execute_mock.call_args_list[1][0][0] == "PRAGMA foreign_keys=ON" + assert len(execute_args) == 2 + assert execute_args[0] == "PRAGMA cache_size = -8192" + assert execute_args[1] == "PRAGMA foreign_keys=ON" + + assert instance_mock._db_supports_row_number == db_supports_row_number def test_basic_sanity_check(hass_recorder): diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index 9ae4b467da5..8a0da39cde3 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -1806,7 +1806,13 @@ def test_compile_hourly_statistics_changing_statistics( assert "Error while processing event StatisticsTask" not in caplog.text -def test_compile_statistics_hourly_summary(hass_recorder, caplog): +@pytest.mark.parametrize( + "db_supports_row_number,in_log,not_in_log", + [(True, "row_number", None), (False, None, "row_number")], +) +def test_compile_statistics_hourly_summary( + hass_recorder, caplog, db_supports_row_number, in_log, not_in_log +): """Test compiling hourly statistics.""" zero = dt_util.utcnow() zero = zero.replace(minute=0, second=0, microsecond=0) @@ -1815,6 +1821,7 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog): zero += timedelta(hours=1) hass = hass_recorder() recorder = hass.data[DATA_INSTANCE] + recorder._db_supports_row_number = db_supports_row_number setup_component(hass, "sensor", {}) attributes = { "device_class": None, @@ -2052,6 +2059,10 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog): end += timedelta(hours=1) assert stats == expected_stats assert "Error while processing event StatisticsTask" not in caplog.text + if in_log: + assert in_log in caplog.text + if not_in_log: + assert not_in_log not in caplog.text def record_states(hass, zero, entity_id, attributes, seq=None):