WIP: [component/recorder] Refactoring & better handling of SQLAlchemy Sessions (#5607)

* Refactor recorder and Sessions

* Cover #4352

* NO_reset_on_return

* contextmanager

* coverage
This commit is contained in:
Johann Kellerman 2017-02-08 07:47:41 +02:00 committed by Paulus Schoutsen
parent bdebe5d53c
commit 490ef6afad
3 changed files with 221 additions and 153 deletions

View file

@ -64,7 +64,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
""" """
entity_ids = (entity_id.lower(), ) if entity_id is not None else None entity_ids = (entity_id.lower(), ) if entity_id is not None else None
states = recorder.get_model('States') states = recorder.get_model('States')
query = recorder.query('States').filter( query = recorder.query(states).filter(
(states.domain.in_(SIGNIFICANT_DOMAINS) | (states.domain.in_(SIGNIFICANT_DOMAINS) |
(states.last_changed == states.last_updated)) & (states.last_changed == states.last_updated)) &
(states.last_updated > start_time)) (states.last_updated > start_time))

View file

@ -13,6 +13,7 @@ import threading
import time import time
from datetime import timedelta, datetime from datetime import timedelta, datetime
from typing import Any, Union, Optional, List, Dict from typing import Any, Union, Optional, List, Dict
from contextlib import contextmanager
import voluptuous as vol import voluptuous as vol
@ -22,7 +23,7 @@ from homeassistant.const import (
CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import track_point_in_utc_time from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType, QueryType from homeassistant.helpers.typing import ConfigType, QueryType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -39,6 +40,7 @@ CONF_PURGE_DAYS = 'purge_days'
RETRIES = 3 RETRIES = 3
CONNECT_RETRY_WAIT = 10 CONNECT_RETRY_WAIT = 10
QUERY_RETRY_WAIT = 0.1 QUERY_RETRY_WAIT = 0.1
ERROR_QUERY = "Error during query: %s"
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({ DOMAIN: vol.Schema({
@ -62,28 +64,43 @@ _INSTANCE = None # type: Any
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# These classes will be populated during setup() # These classes will be populated during setup()
# pylint: disable=invalid-name,no-member # scoped_session, in the same thread session_scope() stays the same
Session = None # pylint: disable=no-member _SESSION = None
@contextmanager
def session_scope():
"""Provide a transactional scope around a series of operations."""
session = _SESSION()
try:
yield session
session.commit()
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(ERROR_QUERY, err)
session.rollback()
raise
finally:
session.close()
# pylint: disable=invalid-sequence-index # pylint: disable=invalid-sequence-index
def execute(q: QueryType) -> List[Any]: def execute(qry: QueryType) -> List[Any]:
"""Query the database and convert the objects to HA native form. """Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
""" """
import sqlalchemy.exc import sqlalchemy.exc
try: with session_scope() as session:
for _ in range(0, RETRIES): for _ in range(0, RETRIES):
try: try:
return [ return [
row for row in row for row in
(row.to_native() for row in q) (row.to_native() for row in qry)
if row is not None] if row is not None]
except sqlalchemy.exc.SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as err:
log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) _LOGGER.error(ERROR_QUERY, err)
finally: session.rollback()
Session.close() time.sleep(QUERY_RETRY_WAIT)
return [] return []
@ -101,9 +118,10 @@ def run_information(point_in_time: Optional[datetime]=None):
start=_INSTANCE.recording_start, start=_INSTANCE.recording_start,
closed_incorrect=False) closed_incorrect=False)
return query('RecorderRuns').filter( with session_scope():
(recorder_runs.start < point_in_time) & return query('RecorderRuns').filter(
(recorder_runs.end > point_in_time)).first() (recorder_runs.start < point_in_time) &
(recorder_runs.end > point_in_time)).first()
def setup(hass: HomeAssistant, config: ConfigType) -> bool: def setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -132,10 +150,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
def query(model_name: Union[str, Any], *args) -> QueryType: def query(model_name: Union[str, Any], *args) -> QueryType:
"""Helper to return a query handle.""" """Helper to return a query handle."""
_verify_instance() _verify_instance()
if isinstance(model_name, str): if isinstance(model_name, str):
return Session.query(get_model(model_name), *args) return _SESSION().query(get_model(model_name), *args)
return Session.query(model_name, *args) return _SESSION().query(model_name, *args)
def get_model(model_name: str) -> Any: def get_model(model_name: str) -> Any:
@ -148,22 +165,6 @@ def get_model(model_name: str) -> Any:
return None return None
def log_error(e: Exception, retry_wait: Optional[float]=0,
rollback: Optional[bool]=True,
message: Optional[str]="Error during query: %s") -> None:
"""Log about SQLAlchemy errors in a sane manner."""
import sqlalchemy.exc
if not isinstance(e, sqlalchemy.exc.OperationalError):
_LOGGER.exception(str(e))
else:
_LOGGER.error(message, str(e))
if rollback:
Session.rollback()
if retry_wait:
_LOGGER.info("Retrying in %s seconds", retry_wait)
time.sleep(retry_wait)
class Recorder(threading.Thread): class Recorder(threading.Thread):
"""A threaded recorder class.""" """A threaded recorder class."""
@ -204,18 +205,14 @@ class Recorder(threading.Thread):
self._setup_connection() self._setup_connection()
self._setup_run() self._setup_run()
break break
except sqlalchemy.exc.SQLAlchemyError as e: except sqlalchemy.exc.SQLAlchemyError as err:
log_error(e, retry_wait=CONNECT_RETRY_WAIT, rollback=False, _LOGGER.error("Error during connection setup: %s (retrying "
message="Error during connection setup: %s") "in %s seconds)", err, CONNECT_RETRY_WAIT)
time.sleep(CONNECT_RETRY_WAIT)
if self.purge_days is not None: if self.purge_days is not None:
def purge_ticker(event): async_track_time_interval(
"""Rerun purge every second day.""" self.hass, self._purge_old_data, timedelta(days=2))
self._purge_old_data()
track_point_in_utc_time(self.hass, purge_ticker,
dt_util.utcnow() + timedelta(days=2))
track_point_in_utc_time(self.hass, purge_ticker,
dt_util.utcnow() + timedelta(minutes=5))
while True: while True:
event = self.queue.get() event = self.queue.get()
@ -250,16 +247,17 @@ class Recorder(threading.Thread):
self.queue.task_done() self.queue.task_done()
continue continue
dbevent = Events.from_event(event) with session_scope() as session:
self._commit(dbevent) dbevent = Events.from_event(event)
self._commit(session, dbevent)
if event.event_type != EVENT_STATE_CHANGED: if event.event_type != EVENT_STATE_CHANGED:
self.queue.task_done() self.queue.task_done()
continue continue
dbstate = States.from_event(event) dbstate = States.from_event(event)
dbstate.event_id = dbevent.event_id dbstate.event_id = dbevent.event_id
self._commit(dbstate) self._commit(session, dbstate)
self.queue.task_done() self.queue.task_done()
@ -282,11 +280,14 @@ class Recorder(threading.Thread):
def block_till_db_ready(self): def block_till_db_ready(self):
"""Block until the database session is ready.""" """Block until the database session is ready."""
self.db_ready.wait() self.db_ready.wait(10)
while not self.db_ready.is_set():
_LOGGER.warning('Database not ready, waiting another 10 seconds.')
self.db_ready.wait(10)
def _setup_connection(self): def _setup_connection(self):
"""Ensure database is ready to fly.""" """Ensure database is ready to fly."""
global Session # pylint: disable=global-statement global _SESSION # pylint: disable=invalid-name,global-statement
import homeassistant.components.recorder.models as models import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -298,40 +299,44 @@ class Recorder(threading.Thread):
self.engine = create_engine( self.engine = create_engine(
'sqlite://', 'sqlite://',
connect_args={'check_same_thread': False}, connect_args={'check_same_thread': False},
poolclass=StaticPool) poolclass=StaticPool,
pool_reset_on_return=None)
else: else:
self.engine = create_engine(self.db_url, echo=False) self.engine = create_engine(self.db_url, echo=False)
models.Base.metadata.create_all(self.engine) models.Base.metadata.create_all(self.engine)
session_factory = sessionmaker(bind=self.engine) session_factory = sessionmaker(bind=self.engine)
Session = scoped_session(session_factory) _SESSION = scoped_session(session_factory)
self._migrate_schema() self._migrate_schema()
self.db_ready.set() self.db_ready.set()
def _migrate_schema(self): def _migrate_schema(self):
"""Check if the schema needs to be upgraded.""" """Check if the schema needs to be upgraded."""
import homeassistant.components.recorder.models as models from homeassistant.components.recorder.models import SCHEMA_VERSION
schema_changes = models.SchemaChanges schema_changes = get_model('SchemaChanges')
current_version = getattr(Session.query(schema_changes).order_by( with session_scope() as session:
schema_changes.change_id.desc()).first(), 'schema_version', None) res = session.query(schema_changes).order_by(
schema_changes.change_id.desc()).first()
current_version = getattr(res, 'schema_version', None)
if current_version == models.SCHEMA_VERSION: if current_version == SCHEMA_VERSION:
return return
_LOGGER.debug("Schema version incorrect: %d", current_version) _LOGGER.debug("Schema version incorrect: %s", current_version)
if current_version is None: if current_version is None:
current_version = self._inspect_schema_version() current_version = self._inspect_schema_version()
_LOGGER.debug("No schema version found. Inspected version: %d", _LOGGER.debug("No schema version found. Inspected version: %s",
current_version) current_version)
for version in range(current_version, models.SCHEMA_VERSION): for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1 new_version = version + 1
_LOGGER.info( _LOGGER.info("Upgrading recorder db schema to version %s",
"Upgrading recorder db schema to version %d", new_version) new_version)
self._apply_update(new_version) self._apply_update(new_version)
self._commit(schema_changes(schema_version=new_version)) self._commit(session,
_LOGGER.info( schema_changes(schema_version=new_version))
"Upgraded recorder db schema to version %d", new_version) _LOGGER.info("Upgraded recorder db schema to version %s",
new_version)
def _apply_update(self, new_version): def _apply_update(self, new_version):
"""Perform operations to bring schema up to date.""" """Perform operations to bring schema up to date."""
@ -368,51 +373,54 @@ class Recorder(threading.Thread):
import homeassistant.components.recorder.models as models import homeassistant.components.recorder.models as models
inspector = reflection.Inspector.from_engine(self.engine) inspector = reflection.Inspector.from_engine(self.engine)
indexes = inspector.get_indexes("events") indexes = inspector.get_indexes("events")
for index in indexes: with session_scope() as session:
if index['column_names'] == ["time_fired"]: for index in indexes:
# Schema addition from version 1 detected. This is a new db. if index['column_names'] == ["time_fired"]:
current_version = models.SchemaChanges( # Schema addition from version 1 detected. New DB.
schema_version=models.SCHEMA_VERSION) current_version = models.SchemaChanges(
self._commit(current_version) schema_version=models.SCHEMA_VERSION)
return models.SCHEMA_VERSION self._commit(session, current_version)
return models.SCHEMA_VERSION
# Version 1 schema changes not found, this db needs to be migrated. # Version 1 schema changes not found, this db needs to be migrated.
current_version = models.SchemaChanges(schema_version=0) current_version = models.SchemaChanges(schema_version=0)
self._commit(current_version) self._commit(session, current_version)
return current_version.schema_version return current_version.schema_version
def _close_connection(self): def _close_connection(self):
"""Close the connection.""" """Close the connection."""
global Session # pylint: disable=global-statement global _SESSION # pylint: disable=invalid-name,global-statement
self.engine.dispose() self.engine.dispose()
self.engine = None self.engine = None
Session = None _SESSION = None
def _setup_run(self): def _setup_run(self):
"""Log the start of the current run.""" """Log the start of the current run."""
recorder_runs = get_model('RecorderRuns') recorder_runs = get_model('RecorderRuns')
for run in query('RecorderRuns').filter_by(end=None): with session_scope() as session:
run.closed_incorrect = True for run in query('RecorderRuns').filter_by(end=None):
run.end = self.recording_start run.closed_incorrect = True
_LOGGER.warning("Ended unfinished session (id=%s from %s)", run.end = self.recording_start
run.run_id, run.start) _LOGGER.warning("Ended unfinished session (id=%s from %s)",
Session.add(run) run.run_id, run.start)
session.add(run)
_LOGGER.warning("Found unfinished sessions") _LOGGER.warning("Found unfinished sessions")
self._run = recorder_runs( self._run = recorder_runs(
start=self.recording_start, start=self.recording_start,
created=dt_util.utcnow() created=dt_util.utcnow()
) )
self._commit(self._run) self._commit(session, self._run)
def _close_run(self): def _close_run(self):
"""Save end time for current run.""" """Save end time for current run."""
self._run.end = dt_util.utcnow() self._run.end = dt_util.utcnow()
self._commit(self._run) with session_scope() as session:
self._commit(session, self._run)
self._run = None self._run = None
def _purge_old_data(self): def _purge_old_data(self, _=None):
"""Purge events and states older than purge_days ago.""" """Purge events and states older than purge_days ago."""
from homeassistant.components.recorder.models import Events, States from homeassistant.components.recorder.models import Events, States
@ -429,8 +437,9 @@ class Recorder(threading.Thread):
.delete(synchronize_session=False) .delete(synchronize_session=False)
_LOGGER.debug("Deleted %s states", deleted_rows) _LOGGER.debug("Deleted %s states", deleted_rows)
if self._commit(_purge_states): with session_scope() as session:
_LOGGER.info("Purged states created before %s", purge_before) if self._commit(session, _purge_states):
_LOGGER.info("Purged states created before %s", purge_before)
def _purge_events(session): def _purge_events(session):
deleted_rows = session.query(Events) \ deleted_rows = session.query(Events) \
@ -438,10 +447,9 @@ class Recorder(threading.Thread):
.delete(synchronize_session=False) .delete(synchronize_session=False)
_LOGGER.debug("Deleted %s events", deleted_rows) _LOGGER.debug("Deleted %s events", deleted_rows)
if self._commit(_purge_events): with session_scope() as session:
_LOGGER.info("Purged events created before %s", purge_before) if self._commit(session, _purge_events):
_LOGGER.info("Purged events created before %s", purge_before)
Session.expire_all()
# Execute sqlite vacuum command to free up space on disk # Execute sqlite vacuum command to free up space on disk
if self.engine.driver == 'sqlite': if self.engine.driver == 'sqlite':
@ -449,10 +457,9 @@ class Recorder(threading.Thread):
self.engine.execute("VACUUM") self.engine.execute("VACUUM")
@staticmethod @staticmethod
def _commit(work): def _commit(session, work):
"""Commit & retry work: Either a model or in a function.""" """Commit & retry work: Either a model or in a function."""
import sqlalchemy.exc import sqlalchemy.exc
session = Session()
for _ in range(0, RETRIES): for _ in range(0, RETRIES):
try: try:
if callable(work): if callable(work):
@ -461,8 +468,10 @@ class Recorder(threading.Thread):
session.add(work) session.add(work)
session.commit() session.commit()
return True return True
except sqlalchemy.exc.OperationalError as e: except sqlalchemy.exc.OperationalError as err:
log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) _LOGGER.error(ERROR_QUERY, err)
session.rollback()
time.sleep(QUERY_RETRY_WAIT)
return False return False

View file

@ -3,7 +3,7 @@
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
import unittest import unittest
from unittest.mock import patch, call from unittest.mock import patch, call, MagicMock
import pytest import pytest
from homeassistant.core import callback from homeassistant.core import callback
@ -24,7 +24,6 @@ class TestRecorder(unittest.TestCase):
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}}) recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
self.hass.start() self.hass.start()
recorder._verify_instance() recorder._verify_instance()
self.session = recorder.Session()
recorder._INSTANCE.block_till_done() recorder._INSTANCE.block_till_done()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
@ -42,26 +41,25 @@ class TestRecorder(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
recorder._INSTANCE.block_till_done() recorder._INSTANCE.block_till_done()
for event_id in range(5): with recorder.session_scope() as session:
if event_id < 3: for event_id in range(5):
timestamp = five_days_ago if event_id < 3:
state = 'purgeme' timestamp = five_days_ago
else: state = 'purgeme'
timestamp = now else:
state = 'dontpurgeme' timestamp = now
state = 'dontpurgeme'
self.session.add(recorder.get_model('States')( session.add(recorder.get_model('States')(
entity_id='test.recorder2', entity_id='test.recorder2',
domain='sensor', domain='sensor',
state=state, state=state,
attributes=json.dumps(attributes), attributes=json.dumps(attributes),
last_changed=timestamp, last_changed=timestamp,
last_updated=timestamp, last_updated=timestamp,
created=timestamp, created=timestamp,
event_id=event_id + 1000 event_id=event_id + 1000
)) ))
self.session.commit()
def _add_test_events(self): def _add_test_events(self):
"""Add a few events for testing.""" """Add a few events for testing."""
@ -71,21 +69,23 @@ class TestRecorder(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
recorder._INSTANCE.block_till_done() recorder._INSTANCE.block_till_done()
for event_id in range(5):
if event_id < 2:
timestamp = five_days_ago
event_type = 'EVENT_TEST_PURGE'
else:
timestamp = now
event_type = 'EVENT_TEST'
self.session.add(recorder.get_model('Events')( with recorder.session_scope() as session:
event_type=event_type, for event_id in range(5):
event_data=json.dumps(event_data), if event_id < 2:
origin='LOCAL', timestamp = five_days_ago
created=timestamp, event_type = 'EVENT_TEST_PURGE'
time_fired=timestamp, else:
)) timestamp = now
event_type = 'EVENT_TEST'
session.add(recorder.get_model('Events')(
event_type=event_type,
event_data=json.dumps(event_data),
origin='LOCAL',
created=timestamp,
time_fired=timestamp,
))
def test_saving_state(self): def test_saving_state(self):
"""Test saving and restoring a state.""" """Test saving and restoring a state."""
@ -205,14 +205,15 @@ class TestRecorder(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
recorder._INSTANCE._apply_update(-1) recorder._INSTANCE._apply_update(-1)
def test_schema_update_calls(self): def test_schema_update_calls(self): # pylint: disable=no-self-use
"""Test that schema migrations occurr in correct order.""" """Test that schema migrations occurr in correct order."""
test_version = recorder.models.SchemaChanges(schema_version=0) test_version = recorder.models.SchemaChanges(schema_version=0)
self.session.add(test_version) with recorder.session_scope() as session:
with patch.object(recorder._INSTANCE, '_apply_update') as update: session.add(test_version)
recorder._INSTANCE._migrate_schema() with patch.object(recorder._INSTANCE, '_apply_update') as update:
update.assert_has_calls([call(version+1) for version in range( recorder._INSTANCE._migrate_schema()
0, recorder.models.SCHEMA_VERSION)]) update.assert_has_calls([call(version+1) for version in range(
0, recorder.models.SCHEMA_VERSION)])
@pytest.fixture @pytest.fixture
@ -220,7 +221,7 @@ def hass_recorder():
"""HASS fixture with in-memory recorder.""" """HASS fixture with in-memory recorder."""
hass = get_test_home_assistant() hass = get_test_home_assistant()
def setup_recorder(config): def setup_recorder(config={}):
"""Setup with params.""" """Setup with params."""
db_uri = 'sqlite://' # In memory DB db_uri = 'sqlite://' # In memory DB
conf = {recorder.CONF_DB_URL: db_uri} conf = {recorder.CONF_DB_URL: db_uri}
@ -301,3 +302,61 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder):
assert len(states) == 1 assert len(states) == 1
assert hass.states.get('test.ok') == states[0] assert hass.states.get('test.ok') == states[0]
assert hass.states.get('test.ok').state == 'state2' assert hass.states.get('test.ok').state == 'state2'
def test_recorder_errors_exceptions(hass_recorder): \
# pylint: disable=redefined-outer-name
"""Test session_scope and get_model errors."""
# Model cannot be resolved
assert recorder.get_model('dont-exist') is None
# Verify the instance fails before setup
with pytest.raises(RuntimeError):
recorder._verify_instance()
# Setup the recorder
hass_recorder()
recorder._verify_instance()
# Verify session scope raises (and prints) an exception
with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \
pytest.raises(Exception) as err:
with recorder.session_scope() as session:
session.execute('select * from notthere')
assert e_mock.call_count == 1
assert recorder.ERROR_QUERY[:-4] in e_mock.call_args[0][0]
assert 'no such table' in str(err.value)
def test_recorder_bad_commit(hass_recorder):
"""Bad _commit should retry 3 times."""
hass_recorder()
def work(session):
"""Bad work."""
session.execute('select * from notthere')
with patch('homeassistant.components.recorder.time.sleep') as e_mock, \
recorder.session_scope() as session:
res = recorder._INSTANCE._commit(session, work)
assert res is False
assert e_mock.call_count == 3
def test_recorder_bad_execute(hass_recorder):
"""Bad execute, retry 3 times."""
hass_recorder()
def to_native():
"""Rasie exception."""
from sqlalchemy.exc import SQLAlchemyError
raise SQLAlchemyError()
mck1 = MagicMock()
mck1.to_native = to_native
with patch('homeassistant.components.recorder.time.sleep') as e_mock:
res = recorder.execute((mck1,))
assert res == []
assert e_mock.call_count == 3