From 1acdb28cdd6cc0badd50997698ac6d4fa63879f3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 17 Jul 2020 19:07:37 -1000 Subject: [PATCH] Automatically recover when the sqlite3 database is malformed or corrupted (#37949) * Validate sqlite database on startup and move away if corruption is detected. * do not switch context in test -- its all sync --- homeassistant/components/recorder/__init__.py | 13 +++-- homeassistant/components/recorder/const.py | 1 + homeassistant/components/recorder/util.py | 56 ++++++++++++++++++- tests/components/recorder/test_util.py | 32 ++++++++++- 4 files changed, 96 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index c64b9429cf0..5ac4d226082 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -35,9 +35,9 @@ from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util from . 
import migration, purge -from .const import DATA_INSTANCE +from .const import DATA_INSTANCE, SQLITE_URL_PREFIX from .models import Base, Events, RecorderRuns, States -from .util import session_scope +from .util import session_scope, validate_or_move_away_sqlite_database _LOGGER = logging.getLogger(__name__) @@ -510,7 +510,7 @@ class Recorder(threading.Thread): # We do not import sqlite3 here so mysql/other # users do not have to pay for it to be loaded in # memory - if self.db_url.startswith("sqlite://"): + if self.db_url.startswith(SQLITE_URL_PREFIX): old_isolation = dbapi_connection.isolation_level dbapi_connection.isolation_level = None cursor = dbapi_connection.cursor() @@ -526,13 +526,18 @@ class Recorder(threading.Thread): cursor.execute("SET session wait_timeout=28800") cursor.close() - if self.db_url == "sqlite://" or ":memory:" in self.db_url: + if self.db_url == SQLITE_URL_PREFIX or ":memory:" in self.db_url: kwargs["connect_args"] = {"check_same_thread": False} kwargs["poolclass"] = StaticPool kwargs["pool_reset_on_return"] = None else: kwargs["echo"] = False + if self.db_url != SQLITE_URL_PREFIX and self.db_url.startswith( + SQLITE_URL_PREFIX + ): + validate_or_move_away_sqlite_database(self.db_url) + if self.engine is not None: self.engine.dispose() diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index ed0950b6c6f..fb699d13fb3 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -1,3 +1,4 @@ """Recorder constants.""" DATA_INSTANCE = "recorder_instance" +SQLITE_URL_PREFIX = "sqlite://" diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 883bc41e71b..8a59cc42a33 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -1,16 +1,20 @@ """SQLAlchemy util functions.""" from contextlib import contextmanager import logging +import os import time from 
_LOGGER = logging.getLogger(__name__)

# Suffixes appended to the database path when moving a broken database away:
# the main file itself ("") plus its "-wal" and "-shm" companion files.
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]


def validate_or_move_away_sqlite_database(dburl: str) -> bool:
    """Ensure that the database is valid or move it away.

    Args:
        dburl: Database URL. The leading ``len(SQLITE_URL_PREFIX)`` characters
            are dropped (without being checked) and the remainder is used as
            the filesystem path of the database.

    Returns:
        True when nothing exists at that path yet, or when the file passes
        ``validate_sqlite_database``. False when validation failed and the
        file was renamed out of the way.
    """
    dbpath = dburl[len(SQLITE_URL_PREFIX) :]

    if not os.path.exists(dbpath):
        # Database does not exist yet, this is OK
        return True

    if validate_sqlite_database(dbpath):
        return True

    _move_away_broken_database(dbpath)
    return False


def validate_sqlite_database(dbpath: str) -> bool:
    """Run a quick check on an sqlite database to see if it is corrupt.

    Opens ``dbpath`` with sqlite3 and executes ``PRAGMA QUICK_CHECK``. The rows
    returned by the pragma are not inspected; only a raised
    ``sqlite3.DatabaseError`` marks the file as corrupt.

    Args:
        dbpath: Filesystem path of the sqlite3 database file.

    Returns:
        True when the pragma ran without raising, False when sqlite3 raised
        ``DatabaseError`` (the error is logged with its traceback).
    """
    # Imported inside the function rather than at module level, so sqlite3 is
    # only loaded when a database is actually validated.
    import sqlite3  # pylint: disable=import-outside-toplevel

    try:
        conn = sqlite3.connect(dbpath)
        try:
            conn.cursor().execute("PRAGMA QUICK_CHECK")
        finally:
            # Close on the failure path as well: the caller renames the file
            # right after a failed check, so no handle should stay open on it.
            conn.close()
    except sqlite3.DatabaseError:
        _LOGGER.exception("The database at %s is corrupt or malformed.", dbpath)
        return False

    return True


def _move_away_broken_database(dbfile: str) -> None:
    """Move away a broken sqlite3 database.

    Renames ``dbfile`` and each existing companion file listed in
    ``SQLITE3_POSTFIXES`` by appending ``.corrupt.<UTC ISO timestamp>``.

    Args:
        dbfile: Filesystem path of the main database file.
    """
    # One timestamp is computed up front so every renamed file shares the
    # same suffix.
    isotime = dt_util.utcnow().isoformat()
    corrupt_postfix = f".corrupt.{isotime}"

    _LOGGER.error(
        "The system will rename the corrupt database file %s to %s in order to allow startup to proceed",
        dbfile,
        f"{dbfile}{corrupt_postfix}",
    )

    for postfix in SQLITE3_POSTFIXES:
        path = f"{dbfile}{postfix}"
        if not os.path.exists(path):
            # Companion files are optional; skip the ones that are absent.
            continue
        os.rename(path, f"{path}{corrupt_postfix}")
def test_validate_or_move_away_sqlite_database(hass, tmpdir, caplog):
    """Ensure a malformed sqlite database is moved away."""

    test_dir = tmpdir.mkdir("test_validate_or_move_away_sqlite_database")
    test_db_file = f"{test_dir}/broken.db"
    dburl = f"{SQLITE_URL_PREFIX}{test_db_file}"

    # Nothing has written to this path yet. The result is asserted so that a
    # failing first validation is reported instead of silently discarded.
    assert util.validate_sqlite_database(test_db_file) is True

    # The file is expected to exist after the first validation, and a file in
    # that state must be accepted.
    assert os.path.exists(test_db_file) is True
    assert util.validate_or_move_away_sqlite_database(dburl) is True

    _corrupt_db_file(test_db_file)

    # A corrupted file must be rejected and the problem logged.
    assert util.validate_or_move_away_sqlite_database(dburl) is False

    assert "corrupt or malformed" in caplog.text

    # Validation against the same URL must succeed again after the rejection.
    assert util.validate_or_move_away_sqlite_database(dburl) is True


def _corrupt_db_file(test_db_file):
    """Corrupt an sqlite3 database file by appending plain text to it."""
    # The context manager closes the handle even if the write fails.
    with open(test_db_file, "a") as db_file:
        db_file.write("I am a corrupt db")