Automatically recover when the sqlite3 database is malformed or corrupted (#37949)

* Validate sqlite database on startup and move away if corruption is detected.

* do not switch context in test -- its all sync
This commit is contained in:
J. Nick Koston 2020-07-17 19:07:37 -10:00 committed by GitHub
parent 910b6c9c2c
commit 1acdb28cdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 6 deletions

View file

@ -35,9 +35,9 @@ from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from . import migration, purge from . import migration, purge
from .const import DATA_INSTANCE from .const import DATA_INSTANCE, SQLITE_URL_PREFIX
from .models import Base, Events, RecorderRuns, States from .models import Base, Events, RecorderRuns, States
from .util import session_scope from .util import session_scope, validate_or_move_away_sqlite_database
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -510,7 +510,7 @@ class Recorder(threading.Thread):
# We do not import sqlite3 here so mysql/other # We do not import sqlite3 here so mysql/other
# users do not have to pay for it to be loaded in # users do not have to pay for it to be loaded in
# memory # memory
if self.db_url.startswith("sqlite://"): if self.db_url.startswith(SQLITE_URL_PREFIX):
old_isolation = dbapi_connection.isolation_level old_isolation = dbapi_connection.isolation_level
dbapi_connection.isolation_level = None dbapi_connection.isolation_level = None
cursor = dbapi_connection.cursor() cursor = dbapi_connection.cursor()
@ -526,13 +526,18 @@ class Recorder(threading.Thread):
cursor.execute("SET session wait_timeout=28800") cursor.execute("SET session wait_timeout=28800")
cursor.close() cursor.close()
if self.db_url == "sqlite://" or ":memory:" in self.db_url: if self.db_url == SQLITE_URL_PREFIX or ":memory:" in self.db_url:
kwargs["connect_args"] = {"check_same_thread": False} kwargs["connect_args"] = {"check_same_thread": False}
kwargs["poolclass"] = StaticPool kwargs["poolclass"] = StaticPool
kwargs["pool_reset_on_return"] = None kwargs["pool_reset_on_return"] = None
else: else:
kwargs["echo"] = False kwargs["echo"] = False
if self.db_url != SQLITE_URL_PREFIX and self.db_url.startswith(
SQLITE_URL_PREFIX
):
validate_or_move_away_sqlite_database(self.db_url)
if self.engine is not None: if self.engine is not None:
self.engine.dispose() self.engine.dispose()

View file

@ -1,3 +1,4 @@
"""Recorder constants.""" """Recorder constants."""
DATA_INSTANCE = "recorder_instance" DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"

View file

@ -1,16 +1,20 @@
"""SQLAlchemy util functions.""" """SQLAlchemy util functions."""
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import os
import time import time
from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.exc import OperationalError, SQLAlchemyError
from .const import DATA_INSTANCE import homeassistant.util.dt as dt_util
from .const import DATA_INSTANCE, SQLITE_URL_PREFIX
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
RETRIES = 3 RETRIES = 3
QUERY_RETRY_WAIT = 0.1 QUERY_RETRY_WAIT = 0.1
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
@contextmanager @contextmanager
@ -59,6 +63,7 @@ def execute(qry, to_native=False, validate_entity_ids=True):
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
""" """
for tryno in range(0, RETRIES): for tryno in range(0, RETRIES):
try: try:
timer_start = time.perf_counter() timer_start = time.perf_counter()
@ -94,3 +99,52 @@ def execute(qry, to_native=False, validate_entity_ids=True):
if tryno == RETRIES - 1: if tryno == RETRIES - 1:
raise raise
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
"""Ensure that the database is valid or move it away."""
dbpath = dburl[len(SQLITE_URL_PREFIX) :]
if not os.path.exists(dbpath):
# Database does not exist yet, this is OK
return True
if not validate_sqlite_database(dbpath):
_move_away_broken_database(dbpath)
return False
return True
def validate_sqlite_database(dbpath: str) -> bool:
"""Run a quick check on an sqlite database to see if it is corrupt."""
import sqlite3 # pylint: disable=import-outside-toplevel
try:
conn = sqlite3.connect(dbpath)
conn.cursor().execute("PRAGMA QUICK_CHECK")
conn.close()
except sqlite3.DatabaseError:
_LOGGER.exception("The database at %s is corrupt or malformed.", dbpath)
return False
return True
def _move_away_broken_database(dbfile: str) -> None:
"""Move away a broken sqlite3 database."""
isotime = dt_util.utcnow().isoformat()
corrupt_postfix = f".corrupt.{isotime}"
_LOGGER.error(
"The system will rename the corrupt database file %s to %s in order to allow startup to proceed",
dbfile,
f"{dbfile}{corrupt_postfix}",
)
for postfix in SQLITE3_POSTFIXES:
path = f"{dbfile}{postfix}"
if not os.path.exists(path):
continue
os.rename(path, f"{path}{corrupt_postfix}")

View file

@ -1,8 +1,10 @@
"""Test util methods.""" """Test util methods."""
import os
import pytest import pytest
from homeassistant.components.recorder import util from homeassistant.components.recorder import util
from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from tests.async_mock import MagicMock, patch from tests.async_mock import MagicMock, patch
from tests.common import get_test_home_assistant, init_recorder_component from tests.common import get_test_home_assistant, init_recorder_component
@ -60,3 +62,31 @@ def test_recorder_bad_execute(hass_recorder):
util.execute((mck1,), to_native=True) util.execute((mck1,), to_native=True)
assert e_mock.call_count == 2 assert e_mock.call_count == 2
def test_validate_or_move_away_sqlite_database(hass, tmpdir, caplog):
"""Ensure a malformed sqlite database is moved away."""
test_dir = tmpdir.mkdir("test_validate_or_move_away_sqlite_database")
test_db_file = f"{test_dir}/broken.db"
dburl = f"{SQLITE_URL_PREFIX}{test_db_file}"
util.validate_sqlite_database(test_db_file) is True
assert os.path.exists(test_db_file) is True
assert util.validate_or_move_away_sqlite_database(dburl) is True
_corrupt_db_file(test_db_file)
assert util.validate_or_move_away_sqlite_database(dburl) is False
assert "corrupt or malformed" in caplog.text
assert util.validate_or_move_away_sqlite_database(dburl) is True
def _corrupt_db_file(test_db_file):
"""Corrupt an sqlite3 database file."""
f = open(test_db_file, "a")
f.write("I am a corrupt db")
f.close()