Remove support for databases without ROW_NUMBER (#72092)

commit edd7a3427c
parent 3a13ffcf13
Author: Erik Montnemery
Date: 2022-05-19 04:52:38 +02:00 (committed by GitHub)
5 changed files with 45 additions and 130 deletions
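
Note: every database version this commit keeps supporting (MariaDB >= 10.3.0, MySQL >= 8.0.0, PostgreSQL >= 12.0, SQLite >= 3.31.0) ships the ROW_NUMBER() window function, so the recorder no longer needs the _db_supports_row_number capability flag, the legacy fallback query, or the extra *_ROWNUM version gates removed below. The QUERY_STATISTICS_SUMMARY_SUM column list is not part of this diff; what follows is only a minimal sketch, using a stand-in model, of how such a "rownum" label is typically built with SQLAlchemy's window-function support:

    from sqlalchemy import Column, DateTime, Integer, func, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class StatisticsShortTerm(Base):  # stand-in for the real recorder model
        __tablename__ = "statistics_short_term"
        id = Column(Integer, primary_key=True)
        metadata_id = Column(Integer)
        start = Column(DateTime)

    # Number each statistic's rows newest-first; wrapped in a subquery and
    # filtered on rownum == 1, this yields the latest row per metadata_id.
    rownum = (
        func.row_number()
        .over(
            partition_by=StatisticsShortTerm.metadata_id,
            order_by=StatisticsShortTerm.start.desc(),
        )
        .label("rownum")
    )
    stmt = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start, rownum)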

homeassistant/components/recorder/__init__.py

@@ -180,7 +180,6 @@ class Recorder(threading.Thread):
         self._completed_first_database_setup: bool | None = None
         self.async_migration_event = asyncio.Event()
         self.migration_in_progress = False
-        self._db_supports_row_number = True
         self._database_lock_task: DatabaseLockTask | None = None
         self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
         self._exclude_attributes_by_domain = exclude_attributes_by_domain

homeassistant/components/recorder/statistics.py

@@ -437,22 +437,6 @@ def _compile_hourly_statistics_summary_mean_stmt(
     return stmt


-def _compile_hourly_statistics_summary_sum_legacy_stmt(
-    start_time: datetime, end_time: datetime
-) -> StatementLambdaElement:
-    """Generate the legacy sum statement for hourly statistics.
-
-    This is used for databases not supporting row number.
-    """
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY))
-    stmt += (
-        lambda q: q.filter(StatisticsShortTerm.start >= start_time)
-        .filter(StatisticsShortTerm.start < end_time)
-        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
-    )
-    return stmt
-
-
 def compile_hourly_statistics(
     instance: Recorder, session: Session, start: datetime
 ) -> None:
@@ -481,66 +465,37 @@ def compile_hourly_statistics(
     }

     # Get last hour's last sum
-    if instance._db_supports_row_number:  # pylint: disable=[protected-access]
-        subquery = (
-            session.query(*QUERY_STATISTICS_SUMMARY_SUM)
-            .filter(StatisticsShortTerm.start >= bindparam("start_time"))
-            .filter(StatisticsShortTerm.start < bindparam("end_time"))
-            .subquery()
-        )
-        query = (
-            session.query(subquery)
-            .filter(subquery.c.rownum == 1)
-            .order_by(subquery.c.metadata_id)
-        )
-        stats = execute(query.params(start_time=start_time, end_time=end_time))
+    subquery = (
+        session.query(*QUERY_STATISTICS_SUMMARY_SUM)
+        .filter(StatisticsShortTerm.start >= bindparam("start_time"))
+        .filter(StatisticsShortTerm.start < bindparam("end_time"))
+        .subquery()
+    )
+    query = (
+        session.query(subquery)
+        .filter(subquery.c.rownum == 1)
+        .order_by(subquery.c.metadata_id)
+    )
+    stats = execute(query.params(start_time=start_time, end_time=end_time))
-        if stats:
-            for stat in stats:
-                metadata_id, start, last_reset, state, _sum, _ = stat
-                if metadata_id in summary:
-                    summary[metadata_id].update(
-                        {
-                            "last_reset": process_timestamp(last_reset),
-                            "state": state,
-                            "sum": _sum,
-                        }
-                    )
-                else:
-                    summary[metadata_id] = {
-                        "start": start_time,
-                        "last_reset": process_timestamp(last_reset),
-                        "state": state,
-                        "sum": _sum,
-                    }
-    else:
-        stmt = _compile_hourly_statistics_summary_sum_legacy_stmt(start_time, end_time)
-        stats = execute_stmt_lambda_element(session, stmt)
-        if stats:
-            for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]):  # type: ignore[no-any-return]
-                (
-                    metadata_id,
-                    last_reset,
-                    state,
-                    _sum,
-                ) = next(group)
-                if metadata_id in summary:
-                    summary[metadata_id].update(
-                        {
-                            "start": start_time,
-                            "last_reset": process_timestamp(last_reset),
-                            "state": state,
-                            "sum": _sum,
-                        }
-                    )
-                else:
-                    summary[metadata_id] = {
-                        "start": start_time,
-                        "last_reset": process_timestamp(last_reset),
-                        "state": state,
-                        "sum": _sum,
-                    }
+    if stats:
+        for stat in stats:
+            metadata_id, start, last_reset, state, _sum, _ = stat
+            if metadata_id in summary:
+                summary[metadata_id].update(
+                    {
+                        "last_reset": process_timestamp(last_reset),
+                        "state": state,
+                        "sum": _sum,
+                    }
+                )
+            else:
+                summary[metadata_id] = {
+                    "start": start_time,
+                    "last_reset": process_timestamp(last_reset),
+                    "state": state,
+                    "sum": _sum,
+                }

     # Insert compiled hourly statistics in the database
     for metadata_id, stat in summary.items():
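
For reference, the removed fallback had no rownum column to filter on: the legacy statement ordered rows by (metadata_id, start DESC) and itertools.groupby then took the first row of each group as the most recent one. A minimal standalone sketch of that first-row-per-group trick, with illustrative tuples rather than the real row shape:

    from itertools import groupby

    rows = [  # assumed already sorted by (metadata_id, start DESC)
        (1, "2022-05-19T03:55", 41.0),
        (1, "2022-05-19T03:50", 40.5),
        (2, "2022-05-19T03:55", 7.2),
    ]

    # The first row of each metadata_id group is the latest one.
    latest = [next(group) for _, group in groupby(rows, key=lambda row: row[0])]
    assert latest == [rows[0], rows[2]]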

homeassistant/components/recorder/util.py

@@ -52,12 +52,9 @@ SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
 DEFAULT_YIELD_STATES_ROWS = 32768

 MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
-MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
 MIN_VERSION_MYSQL = AwesomeVersion("8.0.0", AwesomeVersionStrategy.SIMPLEVER)
-MIN_VERSION_MYSQL_ROWNUM = AwesomeVersion("5.8.0", AwesomeVersionStrategy.SIMPLEVER)
 MIN_VERSION_PGSQL = AwesomeVersion("12.0", AwesomeVersionStrategy.SIMPLEVER)
 MIN_VERSION_SQLITE = AwesomeVersion("3.31.0", AwesomeVersionStrategy.SIMPLEVER)
-MIN_VERSION_SQLITE_ROWNUM = AwesomeVersion("3.25.0", AwesomeVersionStrategy.SIMPLEVER)

 # This is the maximum time after the recorder ends the session
 # before we no longer consider startup to be a "restart" and we
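
The removed *_ROWNUM thresholds were each at or below the corresponding minimum supported version (and MySQL's numbering jumped from 5.7 straight to 8.0, so the 5.8.0 threshold could never match a real release), which made the no-ROW_NUMBER branch unreachable on any supported install. A quick sanity check using the same AwesomeVersion comparisons:

    from awesomeversion import AwesomeVersion, AwesomeVersionStrategy

    MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
    MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
    MIN_VERSION_SQLITE = AwesomeVersion("3.31.0", AwesomeVersionStrategy.SIMPLEVER)
    MIN_VERSION_SQLITE_ROWNUM = AwesomeVersion("3.25.0", AwesomeVersionStrategy.SIMPLEVER)

    # A version that passes the support gate can never fall below the ROW_NUMBER gate.
    assert MIN_VERSION_MARIA_DB >= MIN_VERSION_MARIA_DB_ROWNUM
    assert MIN_VERSION_SQLITE >= MIN_VERSION_SQLITE_ROWNUM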
@@ -414,10 +411,6 @@ def setup_connection_for_dialect(
         version_string = result[0][0]
         version = _extract_version_from_server_response(version_string)

-        if version and version < MIN_VERSION_SQLITE_ROWNUM:
-            instance._db_supports_row_number = (  # pylint: disable=[protected-access]
-                False
-            )
         if not version or version < MIN_VERSION_SQLITE:
             _fail_unsupported_version(
                 version or version_string, "SQLite", MIN_VERSION_SQLITE
@@ -448,19 +441,11 @@ def setup_connection_for_dialect(
         is_maria_db = "mariadb" in version_string.lower()

         if is_maria_db:
-            if version and version < MIN_VERSION_MARIA_DB_ROWNUM:
-                instance._db_supports_row_number = (  # pylint: disable=[protected-access]
-                    False
-                )
             if not version or version < MIN_VERSION_MARIA_DB:
                 _fail_unsupported_version(
                     version or version_string, "MariaDB", MIN_VERSION_MARIA_DB
                 )
         else:
-            if version and version < MIN_VERSION_MYSQL_ROWNUM:
-                instance._db_supports_row_number = (  # pylint: disable=[protected-access]
-                    False
-                )
             if not version or version < MIN_VERSION_MYSQL:
                 _fail_unsupported_version(
                     version or version_string, "MySQL", MIN_VERSION_MYSQL

tests/components/recorder/test_util.py

@@ -166,15 +166,12 @@ async def test_last_run_was_recently_clean(
 @pytest.mark.parametrize(
-    "mysql_version, db_supports_row_number",
-    [
-        ("10.3.0-MariaDB", True),
-        ("8.0.0", True),
-    ],
+    "mysql_version",
+    ["10.3.0-MariaDB", "8.0.0"],
 )
-def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
+def test_setup_connection_for_dialect_mysql(mysql_version):
     """Test setting up the connection for a mysql dialect."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -199,18 +196,14 @@ def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
     assert execute_args[0] == "SET session wait_timeout=28800"
     assert execute_args[1] == "SELECT VERSION()"
-    assert instance_mock._db_supports_row_number == db_supports_row_number


 @pytest.mark.parametrize(
-    "sqlite_version, db_supports_row_number",
-    [
-        ("3.31.0", True),
-    ],
+    "sqlite_version",
+    ["3.31.0"],
 )
-def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
+def test_setup_connection_for_dialect_sqlite(sqlite_version):
     """Test setting up the connection for a sqlite dialect."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -246,20 +239,16 @@ def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
     assert execute_args[1] == "PRAGMA synchronous=NORMAL"
     assert execute_args[2] == "PRAGMA foreign_keys=ON"
-    assert instance_mock._db_supports_row_number == db_supports_row_number


 @pytest.mark.parametrize(
-    "sqlite_version, db_supports_row_number",
-    [
-        ("3.31.0", True),
-    ],
+    "sqlite_version",
+    ["3.31.0"],
 )
 def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
-    sqlite_version, db_supports_row_number
+    sqlite_version,
 ):
     """Test setting up the connection for a sqlite dialect with a zero commit interval."""
-    instance_mock = MagicMock(_db_supports_row_number=True, commit_interval=0)
+    instance_mock = MagicMock(commit_interval=0)
     execute_args = []
     close_mock = MagicMock()
@@ -295,8 +284,6 @@ def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
     assert execute_args[1] == "PRAGMA synchronous=FULL"
     assert execute_args[2] == "PRAGMA foreign_keys=ON"
-    assert instance_mock._db_supports_row_number == db_supports_row_number


 @pytest.mark.parametrize(
     "mysql_version,message",
@@ -317,7 +304,7 @@ def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
 )
 def test_fail_outdated_mysql(caplog, mysql_version, message):
     """Test setting up the connection for an outdated mysql version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -353,7 +340,7 @@ def test_fail_outdated_mysql(caplog, mysql_version, message):
 )
 def test_supported_mysql(caplog, mysql_version):
     """Test setting up the connection for a supported mysql version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -396,7 +383,7 @@ def test_supported_mysql(caplog, mysql_version):
 )
 def test_fail_outdated_pgsql(caplog, pgsql_version, message):
     """Test setting up the connection for an outdated PostgreSQL version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -429,7 +416,7 @@ def test_fail_outdated_pgsql(caplog, pgsql_version, message):
 )
 def test_supported_pgsql(caplog, pgsql_version):
     """Test setting up the connection for a supported PostgreSQL version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -474,7 +461,7 @@ def test_supported_pgsql(caplog, pgsql_version):
 )
 def test_fail_outdated_sqlite(caplog, sqlite_version, message):
     """Test setting up the connection for an outdated sqlite version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
@@ -510,7 +497,7 @@ def test_fail_outdated_sqlite(caplog, sqlite_version, message):
 )
 def test_supported_sqlite(caplog, sqlite_version):
     """Test setting up the connection for a supported sqlite version."""
-    instance_mock = MagicMock(_db_supports_row_number=True)
+    instance_mock = MagicMock()
     execute_args = []
     close_mock = MagicMock()
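
The tests stop presetting _db_supports_row_number on the mocks, and the matching asserts are dropped with it: reading an undeclared attribute on a bare MagicMock silently manufactures a child mock instead of raising AttributeError, so a leftover equality assert would no longer be checking a real value. A minimal demonstration:

    from unittest.mock import MagicMock

    instance_mock = MagicMock()

    # Attribute access auto-creates a child mock rather than raising,
    # so asserting against it would compare mocks, not recorder state.
    assert isinstance(instance_mock._db_supports_row_number, MagicMock)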

tests/components/recorder/test_statistics.py

@@ -2279,13 +2279,7 @@ def test_compile_hourly_statistics_changing_statistics(
     assert "Error while processing event StatisticsTask" not in caplog.text


-@pytest.mark.parametrize(
-    "db_supports_row_number,in_log,not_in_log",
-    [(True, "row_number", None), (False, None, "row_number")],
-)
-def test_compile_statistics_hourly_daily_monthly_summary(
-    hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
-):
+def test_compile_statistics_hourly_daily_monthly_summary(hass_recorder, caplog):
     """Test compiling hourly statistics + monthly and daily summary."""
     zero = dt_util.utcnow()
     # August 31st, 23:00 local time
@@ -2299,7 +2293,6 @@ def test_compile_statistics_hourly_daily_monthly_summary(
     # Remove this after dropping the use of the hass_recorder fixture
     hass.config.set_time_zone("America/Regina")
     recorder = hass.data[DATA_INSTANCE]
-    recorder._db_supports_row_number = db_supports_row_number
     setup_component(hass, "sensor", {})
     wait_recording_done(hass)  # Wait for the sensor recorder platform to be added
     attributes = {
@@ -2693,10 +2686,6 @@ def test_compile_statistics_hourly_daily_monthly_summary(
     assert stats == expected_stats
     assert "Error while processing event StatisticsTask" not in caplog.text
-    if in_log:
-        assert in_log in caplog.text
-    if not_in_log:
-        assert not_in_log not in caplog.text


 def record_states(hass, zero, entity_id, attributes, seq=None):
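
The dropped in_log/not_in_log parametrization used pytest's caplog fixture to assert that the string "row_number" did or did not appear in the captured log output, depending on which query path ran; with a single query path left there is nothing to distinguish. A minimal sketch of that caplog pattern (the test name and message here are illustrative, not from the recorder):

    import logging

    def test_fallback_is_logged(caplog):
        logging.getLogger("recorder.demo").warning("row_number fallback used")
        assert "row_number" in caplog.text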