diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index d850778d214..ad68e415df5 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -1,10 +1,12 @@ """Test util methods.""" +from contextlib import AbstractContextManager, nullcontext as does_not_raise from datetime import UTC, datetime, timedelta import os from pathlib import Path import sqlite3 import threading +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -16,7 +18,11 @@ from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.components import recorder from homeassistant.components.recorder import Recorder, util -from homeassistant.components.recorder.const import DOMAIN, SQLITE_URL_PREFIX +from homeassistant.components.recorder.const import ( + DOMAIN, + SQLITE_URL_PREFIX, + SupportedDialect, +) from homeassistant.components.recorder.db_schema import RecorderRuns from homeassistant.components.recorder.history.modern import ( _get_single_entity_start_time_stmt, @@ -27,10 +33,14 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.components.recorder.util import ( MIN_VERSION_SQLITE, + RETRYABLE_MYSQL_ERRORS, UPCOMING_MIN_VERSION_SQLITE, + database_job_retry_wrapper, end_incomplete_runs, is_second_sunday, resolve_period, + retryable_database_job, + retryable_database_job_method, session_scope, ) from homeassistant.const import EVENT_HOMEASSISTANT_STOP @@ -1117,3 +1127,115 @@ async def test_resolve_period(hass: HomeAssistant) -> None: } } ) == (now - timedelta(hours=1, minutes=25), now - timedelta(minutes=25)) + + +NonRetryable = OperationalError(None, None, BaseException()) +Retryable = OperationalError(None, None, BaseException(RETRYABLE_MYSQL_ERRORS[0], "")) + + +@pytest.mark.parametrize( + ("side_effect", "dialect", "expected_result", "num_calls"), + [ + (None, SupportedDialect.MYSQL, does_not_raise(), 1), + (ValueError, SupportedDialect.MYSQL, pytest.raises(ValueError), 1), + (NonRetryable, SupportedDialect.MYSQL, pytest.raises(OperationalError), 1), + (Retryable, SupportedDialect.MYSQL, pytest.raises(OperationalError), 5), + (NonRetryable, SupportedDialect.SQLITE, pytest.raises(OperationalError), 1), + (Retryable, SupportedDialect.SQLITE, pytest.raises(OperationalError), 1), + ], +) +def test_database_job_retry_wrapper( + side_effect: Any, + dialect: str, + expected_result: AbstractContextManager, + num_calls: int, +) -> None: + """Test database_job_retry_wrapper.""" + + instance = Mock() + instance.db_retry_wait = 0 + instance.engine.dialect.name = dialect + mock_job = Mock(side_effect=side_effect) + + @database_job_retry_wrapper(description="test") + def job(instance, *args, **kwargs) -> None: + mock_job() + + with expected_result: + job(instance) + + assert len(mock_job.mock_calls) == num_calls + + +@pytest.mark.parametrize( + ("side_effect", "dialect", "retval", "expected_result"), + [ + (None, SupportedDialect.MYSQL, False, does_not_raise()), + (None, SupportedDialect.MYSQL, True, does_not_raise()), + (ValueError, SupportedDialect.MYSQL, False, pytest.raises(ValueError)), + (NonRetryable, SupportedDialect.MYSQL, True, does_not_raise()), + (Retryable, SupportedDialect.MYSQL, False, does_not_raise()), + (NonRetryable, SupportedDialect.SQLITE, True, does_not_raise()), + (Retryable, SupportedDialect.SQLITE, True, does_not_raise()), + ], +) +def test_retryable_database_job( + side_effect: Any, + retval: bool, + expected_result: AbstractContextManager, + dialect: str, +) -> None: + """Test retryable_database_job.""" + + instance = Mock() + instance.db_retry_wait = 0 + instance.engine.dialect.name = dialect + mock_job = Mock(side_effect=side_effect) + + @retryable_database_job(description="test") + def job(instance, *args, **kwargs) -> bool: + mock_job() + return retval + + with expected_result: + assert job(instance) == retval + + assert len(mock_job.mock_calls) == 1 + + +@pytest.mark.parametrize( + ("side_effect", "dialect", "retval", "expected_result"), + [ + (None, SupportedDialect.MYSQL, False, does_not_raise()), + (None, SupportedDialect.MYSQL, True, does_not_raise()), + (ValueError, SupportedDialect.MYSQL, False, pytest.raises(ValueError)), + (NonRetryable, SupportedDialect.MYSQL, True, does_not_raise()), + (Retryable, SupportedDialect.MYSQL, False, does_not_raise()), + (NonRetryable, SupportedDialect.SQLITE, True, does_not_raise()), + (Retryable, SupportedDialect.SQLITE, True, does_not_raise()), + ], +) +def test_retryable_database_job_method( + side_effect: Any, + retval: bool, + expected_result: AbstractContextManager, + dialect: str, +) -> None: + """Test retryable_database_job_method.""" + + instance = Mock() + instance.db_retry_wait = 0 + instance.engine.dialect.name = dialect + mock_job = Mock(side_effect=side_effect) + + class Test: + @retryable_database_job_method(description="test") + def job(self, instance, *args, **kwargs) -> bool: + mock_job() + return retval + + test = Test() + with expected_result: + assert test.job(instance) == retval + + assert len(mock_job.mock_calls) == 1