From e0dd5a855870e007cc7c8900d79d73e8d56ce971 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 11 Jul 2016 00:46:56 -0700 Subject: [PATCH] Tweak Recorder --- homeassistant/components/recorder/__init__.py | 62 +++++--- homeassistant/components/recorder/models.py | 71 +++++---- requirements_all.txt | 2 +- tests/__init__.py | 4 + tests/components/mqtt/__init__.py | 1 + tests/components/recorder/__init__.py | 1 + .../test_init.py} | 7 +- tests/components/recorder/test_models.py | 140 ++++++++++++++++++ 8 files changed, 234 insertions(+), 54 deletions(-) create mode 100644 tests/components/mqtt/__init__.py create mode 100644 tests/components/recorder/__init__.py rename tests/components/{test_recorder.py => recorder/test_init.py} (97%) create mode 100644 tests/components/recorder/test_models.py diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 77f70d00000..b52bce47c17 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -23,7 +23,7 @@ from homeassistant.helpers.event import track_point_in_utc_time DOMAIN = "recorder" -REQUIREMENTS = ['sqlalchemy==1.0.13'] +REQUIREMENTS = ['sqlalchemy==1.0.14'] DEFAULT_URL = "sqlite:///{hass_config_path}" DEFAULT_DB_FILE = "home-assistant_v2.db" @@ -164,6 +164,8 @@ class Recorder(threading.Thread): from homeassistant.components.recorder.models import Events, States import sqlalchemy.exc + global _INSTANCE + while True: try: self._setup_connection() @@ -183,6 +185,8 @@ class Recorder(threading.Thread): if event == self.quit_object: self._close_run() + self._close_connection() + _INSTANCE = None self.queue.task_done() return @@ -190,25 +194,34 @@ class Recorder(threading.Thread): self.queue.task_done() continue + session = Session() + dbevent = Events.from_event(event) + session.add(dbevent) + for _ in range(0, RETRIES): try: - event_id = Events.record_event(Session, event) + session.commit() break except sqlalchemy.exc.OperationalError as e: - log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) + log_error(e, retry_wait=QUERY_RETRY_WAIT, + rollback=True) - if event.event_type == EVENT_STATE_CHANGED: - for _ in range(0, RETRIES): - try: - States.record_state( - Session, - event.data['entity_id'], - event.data.get('new_state'), - event_id) - break - except sqlalchemy.exc.OperationalError as e: - log_error(e, retry_wait=QUERY_RETRY_WAIT, - rollback=True) + if event.event_type != EVENT_STATE_CHANGED: + self.queue.task_done() + continue + + session = Session() + dbstate = States.from_event(event) + + for _ in range(0, RETRIES): + try: + dbstate.event_id = dbevent.event_id + session.add(dbstate) + session.commit() + break + except sqlalchemy.exc.OperationalError as e: + log_error(e, retry_wait=QUERY_RETRY_WAIT, + rollback=True) self.queue.task_done() @@ -219,7 +232,7 @@ class Recorder(threading.Thread): def shutdown(self, event): """Tell the recorder to shut down.""" self.queue.put(self.quit_object) - self.block_till_done() + self.queue.join() def block_till_done(self): """Block till all events processed.""" @@ -253,6 +266,13 @@ class Recorder(threading.Thread): Session = scoped_session(session_factory) self.db_ready.set() + def _close_connection(self): + """Close the connection.""" + global Session + self.engine.dispose() + self.engine = None + Session = None + def _setup_run(self): """Log the start of the current run.""" recorder_runs = get_model('RecorderRuns') @@ -269,14 +289,16 @@ class Recorder(threading.Thread): start=self.recording_start, created=dt_util.utcnow() ) - Session().add(self._run) - Session().commit() + session = Session() + session.add(self._run) + session.commit() def _close_run(self): """Save end time for current run.""" self._run.end = dt_util.utcnow() - Session().add(self._run) - Session().commit() + session = Session() + session.add(self._run) + session.commit() self._run = None def _purge_old_data(self): diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 6f6cc28dbfc..73dcb8fd9a3 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -7,9 +7,11 @@ import logging from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, String, Text, distinct) from sqlalchemy.ext.declarative import declarative_base + import homeassistant.util.dt as dt_util from homeassistant.core import Event, EventOrigin, State from homeassistant.remote import JSONEncoder +from homeassistant.helpers.entity import split_entity_id # SQLAlchemy Schema # pylint: disable=invalid-name @@ -31,17 +33,12 @@ class Events(Base): created = Column(DateTime(timezone=True), default=datetime.utcnow) @staticmethod - def record_event(session, event): + def from_event(event): """Save an event to the database.""" - dbevent = Events(event_type=event.event_type, - event_data=json.dumps(event.data, cls=JSONEncoder), - origin=str(event.origin), - time_fired=event.time_fired) - - session.add(dbevent) - session.commit() - - return dbevent.event_id + return Events(event_type=event.event_type, + event_data=json.dumps(event.data, cls=JSONEncoder), + origin=str(event.origin), + time_fired=event.time_fired) def to_native(self): """Convert to a natve HA Event.""" @@ -50,7 +47,7 @@ class Events(Base): self.event_type, json.loads(self.event_data), EventOrigin(self.origin), - dt_util.UTC.localize(self.time_fired) + _process_timestamp(self.time_fired) ) except ValueError: # When json.loads fails @@ -68,7 +65,6 @@ class States(Base): entity_id = Column(String(64)) state = Column(String(255)) attributes = Column(Text) - origin = Column(String(32)) event_id = Column(Integer, ForeignKey('events.event_id')) last_changed = Column(DateTime(timezone=True), default=datetime.utcnow) last_updated = Column(DateTime(timezone=True), default=datetime.utcnow) @@ -80,19 +76,20 @@ class States(Base): 'domain', 'last_updated', 'entity_id'), ) @staticmethod - def record_state(session, entity_id, state, event_id): - """Save a state to the database.""" - now = dt_util.utcnow() + def from_event(event): + """Create object from a state_changed event.""" + entity_id = event.data['entity_id'] + state = event.data.get('new_state') - dbstate = States(event_id=event_id, entity_id=entity_id) + dbstate = States(entity_id=entity_id) # State got deleted if state is None: dbstate.state = '' - dbstate.domain = '' + dbstate.domain = split_entity_id(entity_id)[0] dbstate.attributes = '{}' - dbstate.last_changed = now - dbstate.last_updated = now + dbstate.last_changed = event.time_fired + dbstate.last_updated = event.time_fired else: dbstate.domain = state.domain dbstate.state = state.state @@ -100,8 +97,7 @@ class States(Base): dbstate.last_changed = state.last_changed dbstate.last_updated = state.last_updated - session().add(dbstate) - session().commit() + return dbstate def to_native(self): """Convert to an HA state object.""" @@ -109,8 +105,8 @@ class States(Base): return State( self.entity_id, self.state, json.loads(self.attributes), - dt_util.UTC.localize(self.last_changed), - dt_util.UTC.localize(self.last_updated) + _process_timestamp(self.last_changed), + _process_timestamp(self.last_updated) ) except ValueError: # When json.loads fails @@ -135,17 +131,32 @@ class RecorderRuns(Base): Specify point_in_time if you want to know which existed at that point in time inside the run. """ - from homeassistant.components.recorder import Session, _verify_instance - _verify_instance() + from sqlalchemy.orm.session import Session - query = Session().query(distinct(States.entity_id)).filter( - States.created >= self.start) + session = Session.object_session(self) - if point_in_time is not None or self.end is not None: - query = query.filter(States.created < point_in_time) + assert session is not None, 'RecorderRuns need to be persisted' - return [row.entity_id for row in query] + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] def to_native(self): """Return self, native format is this model.""" return self + + +def _process_timestamp(ts): + """Process a timestamp into datetime object.""" + if ts is None: + return None + elif ts.tzinfo is None: + return dt_util.UTC.localize(ts) + else: + return dt_util.as_utc(ts) diff --git a/requirements_all.txt b/requirements_all.txt index 5e08ee42f67..9fba1a239d9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -387,7 +387,7 @@ somecomfort==0.2.1 speedtest-cli==0.3.4 # homeassistant.components.recorder -sqlalchemy==1.0.13 +sqlalchemy==1.0.14 # homeassistant.components.http static3==0.7.0 diff --git a/tests/__init__.py b/tests/__init__.py index a931604fdce..2c44763f234 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,13 @@ """Setup some common test helper things.""" import functools +import logging from homeassistant import util from homeassistant.util import location +logging.basicConfig() +logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + def test_real(func): """Force a function to require a keyword _test_real to be passed in.""" diff --git a/tests/components/mqtt/__init__.py b/tests/components/mqtt/__init__.py new file mode 100644 index 00000000000..d5906361541 --- /dev/null +++ b/tests/components/mqtt/__init__.py @@ -0,0 +1 @@ +"""Tests for MQTT component.""" diff --git a/tests/components/recorder/__init__.py b/tests/components/recorder/__init__.py new file mode 100644 index 00000000000..fca6a655ba4 --- /dev/null +++ b/tests/components/recorder/__init__.py @@ -0,0 +1 @@ +"""Tests for Recorder component.""" diff --git a/tests/components/test_recorder.py b/tests/components/recorder/test_init.py similarity index 97% rename from tests/components/test_recorder.py rename to tests/components/recorder/test_init.py index 08efaa71bbf..7519443f1e4 100644 --- a/tests/components/test_recorder.py +++ b/tests/components/recorder/test_init.py @@ -30,7 +30,6 @@ class TestRecorder(unittest.TestCase): def tearDown(self): # pylint: disable=invalid-name """Stop everything that was started.""" self.hass.stop() - recorder._INSTANCE.block_till_done() def _add_test_states(self): """Add multiple states to the db for testing.""" @@ -97,8 +96,10 @@ class TestRecorder(unittest.TestCase): self.hass.pool.block_till_done() recorder._INSTANCE.block_till_done() - states = recorder.execute( - recorder.query('States')) + db_states = recorder.query('States') + states = recorder.execute(db_states) + + assert db_states[0].event_id is not None self.assertEqual(1, len(states)) self.assertEqual(self.hass.states.get(entity_id), states[0]) diff --git a/tests/components/recorder/test_models.py b/tests/components/recorder/test_models.py new file mode 100644 index 00000000000..55c3e019f15 --- /dev/null +++ b/tests/components/recorder/test_models.py @@ -0,0 +1,140 @@ +"""The tests for the Recorder component.""" +# pylint: disable=too-many-public-methods,protected-access +import unittest +from datetime import datetime + +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +import homeassistant.core as ha +from homeassistant.const import EVENT_STATE_CHANGED +from homeassistant.util import dt +from homeassistant.components.recorder.models import ( + Base, Events, States, RecorderRuns) + +engine = None +Session = None + + +def setUpModule(): + """Set up a database to use.""" + global engine, Session + + engine = create_engine("sqlite://") + Base.metadata.create_all(engine) + session_factory = sessionmaker(bind=engine) + Session = scoped_session(session_factory) + + +def tearDownModule(): + """Close database.""" + global engine, Session + + engine.dispose() + engine = None + Session = None + + +class TestEvents(unittest.TestCase): + """Test Events model.""" + + def test_from_event(self): + """Test converting event to db event.""" + event = ha.Event('test_event', { + 'some_data': 15 + }) + assert event == Events.from_event(event).to_native() + + +class TestStates(unittest.TestCase): + """Test States model.""" + + def test_from_event(self): + """Test converting event to db state.""" + state = ha.State('sensor.temperature', '18') + event = ha.Event(EVENT_STATE_CHANGED, { + 'entity_id': 'sensor.temperature', + 'old_state': None, + 'new_state': state, + }) + assert state == States.from_event(event).to_native() + + def test_from_event_to_delete_state(self): + """Test converting deleting state event to db state.""" + event = ha.Event(EVENT_STATE_CHANGED, { + 'entity_id': 'sensor.temperature', + 'old_state': ha.State('sensor.temperature', '18'), + 'new_state': None, + }) + db_state = States.from_event(event) + + assert db_state.entity_id == 'sensor.temperature' + assert db_state.domain == 'sensor' + assert db_state.state == '' + assert db_state.last_changed == event.time_fired + assert db_state.last_updated == event.time_fired + + +class TestRecorderRuns(unittest.TestCase): + """Test recorder run model.""" + + def setUp(self): + """Set up recorder runs.""" + self.session = session = Session() + session.query(Events).delete() + session.query(States).delete() + session.query(RecorderRuns).delete() + + def tearDown(self): + """Clean up.""" + self.session.rollback() + + def test_entity_ids(self): + """Test if entity ids helper method works.""" + run = RecorderRuns( + start=datetime(2016, 7, 9, 11, 0, 0, tzinfo=dt.UTC), + end=datetime(2016, 7, 9, 23, 0, 0, tzinfo=dt.UTC), + closed_incorrect=False, + created=datetime(2016, 7, 9, 11, 0, 0, tzinfo=dt.UTC), + ) + + self.session.add(run) + self.session.commit() + + before_run = datetime(2016, 7, 9, 8, 0, 0, tzinfo=dt.UTC) + in_run = datetime(2016, 7, 9, 13, 0, 0, tzinfo=dt.UTC) + in_run2 = datetime(2016, 7, 9, 15, 0, 0, tzinfo=dt.UTC) + in_run3 = datetime(2016, 7, 9, 18, 0, 0, tzinfo=dt.UTC) + after_run = datetime(2016, 7, 9, 23, 30, 0, tzinfo=dt.UTC) + + assert run.to_native() == run + assert run.entity_ids() == [] + + self.session.add(States( + entity_id='sensor.temperature', + state='20', + last_changed=before_run, + last_updated=before_run, + )) + self.session.add(States( + entity_id='sensor.sound', + state='10', + last_changed=after_run, + last_updated=after_run, + )) + + self.session.add(States( + entity_id='sensor.humidity', + state='76', + last_changed=in_run, + last_updated=in_run, + )) + self.session.add(States( + entity_id='sensor.lux', + state='5', + last_changed=in_run3, + last_updated=in_run3, + )) + + assert sorted(run.entity_ids()) == ['sensor.humidity', 'sensor.lux'] + assert run.entity_ids(in_run2) == ['sensor.humidity']