Tweak Recorder

This commit is contained in:
Paulus Schoutsen 2016-07-11 00:46:56 -07:00
parent bde2f0d5a0
commit e0dd5a8558
8 changed files with 234 additions and 54 deletions

View file

@ -23,7 +23,7 @@ from homeassistant.helpers.event import track_point_in_utc_time
DOMAIN = "recorder" DOMAIN = "recorder"
REQUIREMENTS = ['sqlalchemy==1.0.13'] REQUIREMENTS = ['sqlalchemy==1.0.14']
DEFAULT_URL = "sqlite:///{hass_config_path}" DEFAULT_URL = "sqlite:///{hass_config_path}"
DEFAULT_DB_FILE = "home-assistant_v2.db" DEFAULT_DB_FILE = "home-assistant_v2.db"
@ -164,6 +164,8 @@ class Recorder(threading.Thread):
from homeassistant.components.recorder.models import Events, States from homeassistant.components.recorder.models import Events, States
import sqlalchemy.exc import sqlalchemy.exc
global _INSTANCE
while True: while True:
try: try:
self._setup_connection() self._setup_connection()
@ -183,6 +185,8 @@ class Recorder(threading.Thread):
if event == self.quit_object: if event == self.quit_object:
self._close_run() self._close_run()
self._close_connection()
_INSTANCE = None
self.queue.task_done() self.queue.task_done()
return return
@ -190,25 +194,34 @@ class Recorder(threading.Thread):
self.queue.task_done() self.queue.task_done()
continue continue
session = Session()
dbevent = Events.from_event(event)
session.add(dbevent)
for _ in range(0, RETRIES): for _ in range(0, RETRIES):
try: try:
event_id = Events.record_event(Session, event) session.commit()
break break
except sqlalchemy.exc.OperationalError as e: except sqlalchemy.exc.OperationalError as e:
log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) log_error(e, retry_wait=QUERY_RETRY_WAIT,
rollback=True)
if event.event_type == EVENT_STATE_CHANGED: if event.event_type != EVENT_STATE_CHANGED:
for _ in range(0, RETRIES): self.queue.task_done()
try: continue
States.record_state(
Session, session = Session()
event.data['entity_id'], dbstate = States.from_event(event)
event.data.get('new_state'),
event_id) for _ in range(0, RETRIES):
break try:
except sqlalchemy.exc.OperationalError as e: dbstate.event_id = dbevent.event_id
log_error(e, retry_wait=QUERY_RETRY_WAIT, session.add(dbstate)
rollback=True) session.commit()
break
except sqlalchemy.exc.OperationalError as e:
log_error(e, retry_wait=QUERY_RETRY_WAIT,
rollback=True)
self.queue.task_done() self.queue.task_done()
@ -219,7 +232,7 @@ class Recorder(threading.Thread):
def shutdown(self, event): def shutdown(self, event):
"""Tell the recorder to shut down.""" """Tell the recorder to shut down."""
self.queue.put(self.quit_object) self.queue.put(self.quit_object)
self.block_till_done() self.queue.join()
def block_till_done(self): def block_till_done(self):
"""Block till all events processed.""" """Block till all events processed."""
@ -253,6 +266,13 @@ class Recorder(threading.Thread):
Session = scoped_session(session_factory) Session = scoped_session(session_factory)
self.db_ready.set() self.db_ready.set()
def _close_connection(self):
"""Close the connection."""
global Session
self.engine.dispose()
self.engine = None
Session = None
def _setup_run(self): def _setup_run(self):
"""Log the start of the current run.""" """Log the start of the current run."""
recorder_runs = get_model('RecorderRuns') recorder_runs = get_model('RecorderRuns')
@ -269,14 +289,16 @@ class Recorder(threading.Thread):
start=self.recording_start, start=self.recording_start,
created=dt_util.utcnow() created=dt_util.utcnow()
) )
Session().add(self._run) session = Session()
Session().commit() session.add(self._run)
session.commit()
def _close_run(self): def _close_run(self):
"""Save end time for current run.""" """Save end time for current run."""
self._run.end = dt_util.utcnow() self._run.end = dt_util.utcnow()
Session().add(self._run) session = Session()
Session().commit() session.add(self._run)
session.commit()
self._run = None self._run = None
def _purge_old_data(self): def _purge_old_data(self):

View file

@ -7,9 +7,11 @@ import logging
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
String, Text, distinct) String, Text, distinct)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.core import Event, EventOrigin, State from homeassistant.core import Event, EventOrigin, State
from homeassistant.remote import JSONEncoder from homeassistant.remote import JSONEncoder
from homeassistant.helpers.entity import split_entity_id
# SQLAlchemy Schema # SQLAlchemy Schema
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -31,17 +33,12 @@ class Events(Base):
created = Column(DateTime(timezone=True), default=datetime.utcnow) created = Column(DateTime(timezone=True), default=datetime.utcnow)
@staticmethod @staticmethod
def record_event(session, event): def from_event(event):
"""Save an event to the database.""" """Save an event to the database."""
dbevent = Events(event_type=event.event_type, return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder), event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin), origin=str(event.origin),
time_fired=event.time_fired) time_fired=event.time_fired)
session.add(dbevent)
session.commit()
return dbevent.event_id
def to_native(self): def to_native(self):
"""Convert to a natve HA Event.""" """Convert to a natve HA Event."""
@ -50,7 +47,7 @@ class Events(Base):
self.event_type, self.event_type,
json.loads(self.event_data), json.loads(self.event_data),
EventOrigin(self.origin), EventOrigin(self.origin),
dt_util.UTC.localize(self.time_fired) _process_timestamp(self.time_fired)
) )
except ValueError: except ValueError:
# When json.loads fails # When json.loads fails
@ -68,7 +65,6 @@ class States(Base):
entity_id = Column(String(64)) entity_id = Column(String(64))
state = Column(String(255)) state = Column(String(255))
attributes = Column(Text) attributes = Column(Text)
origin = Column(String(32))
event_id = Column(Integer, ForeignKey('events.event_id')) event_id = Column(Integer, ForeignKey('events.event_id'))
last_changed = Column(DateTime(timezone=True), default=datetime.utcnow) last_changed = Column(DateTime(timezone=True), default=datetime.utcnow)
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow) last_updated = Column(DateTime(timezone=True), default=datetime.utcnow)
@ -80,19 +76,20 @@ class States(Base):
'domain', 'last_updated', 'entity_id'), ) 'domain', 'last_updated', 'entity_id'), )
@staticmethod @staticmethod
def record_state(session, entity_id, state, event_id): def from_event(event):
"""Save a state to the database.""" """Create object from a state_changed event."""
now = dt_util.utcnow() entity_id = event.data['entity_id']
state = event.data.get('new_state')
dbstate = States(event_id=event_id, entity_id=entity_id) dbstate = States(entity_id=entity_id)
# State got deleted # State got deleted
if state is None: if state is None:
dbstate.state = '' dbstate.state = ''
dbstate.domain = '' dbstate.domain = split_entity_id(entity_id)[0]
dbstate.attributes = '{}' dbstate.attributes = '{}'
dbstate.last_changed = now dbstate.last_changed = event.time_fired
dbstate.last_updated = now dbstate.last_updated = event.time_fired
else: else:
dbstate.domain = state.domain dbstate.domain = state.domain
dbstate.state = state.state dbstate.state = state.state
@ -100,8 +97,7 @@ class States(Base):
dbstate.last_changed = state.last_changed dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated dbstate.last_updated = state.last_updated
session().add(dbstate) return dbstate
session().commit()
def to_native(self): def to_native(self):
"""Convert to an HA state object.""" """Convert to an HA state object."""
@ -109,8 +105,8 @@ class States(Base):
return State( return State(
self.entity_id, self.state, self.entity_id, self.state,
json.loads(self.attributes), json.loads(self.attributes),
dt_util.UTC.localize(self.last_changed), _process_timestamp(self.last_changed),
dt_util.UTC.localize(self.last_updated) _process_timestamp(self.last_updated)
) )
except ValueError: except ValueError:
# When json.loads fails # When json.loads fails
@ -135,17 +131,32 @@ class RecorderRuns(Base):
Specify point_in_time if you want to know which existed at that point Specify point_in_time if you want to know which existed at that point
in time inside the run. in time inside the run.
""" """
from homeassistant.components.recorder import Session, _verify_instance from sqlalchemy.orm.session import Session
_verify_instance()
query = Session().query(distinct(States.entity_id)).filter( session = Session.object_session(self)
States.created >= self.start)
if point_in_time is not None or self.end is not None: assert session is not None, 'RecorderRuns need to be persisted'
query = query.filter(States.created < point_in_time)
return [row.entity_id for row in query] query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self): def to_native(self):
"""Return self, native format is this model.""" """Return self, native format is this model."""
return self return self
def _process_timestamp(ts):
"""Process a timestamp into datetime object."""
if ts is None:
return None
elif ts.tzinfo is None:
return dt_util.UTC.localize(ts)
else:
return dt_util.as_utc(ts)

View file

@ -387,7 +387,7 @@ somecomfort==0.2.1
speedtest-cli==0.3.4 speedtest-cli==0.3.4
# homeassistant.components.recorder # homeassistant.components.recorder
sqlalchemy==1.0.13 sqlalchemy==1.0.14
# homeassistant.components.http # homeassistant.components.http
static3==0.7.0 static3==0.7.0

View file

@ -1,9 +1,13 @@
"""Setup some common test helper things.""" """Setup some common test helper things."""
import functools import functools
import logging
from homeassistant import util from homeassistant import util
from homeassistant.util import location from homeassistant.util import location
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
def test_real(func): def test_real(func):
"""Force a function to require a keyword _test_real to be passed in.""" """Force a function to require a keyword _test_real to be passed in."""

View file

@ -0,0 +1 @@
"""Tests for MQTT component."""

View file

@ -0,0 +1 @@
"""Tests for Recorder component."""

View file

@ -30,7 +30,6 @@ class TestRecorder(unittest.TestCase):
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
recorder._INSTANCE.block_till_done()
def _add_test_states(self): def _add_test_states(self):
"""Add multiple states to the db for testing.""" """Add multiple states to the db for testing."""
@ -97,8 +96,10 @@ class TestRecorder(unittest.TestCase):
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
recorder._INSTANCE.block_till_done() recorder._INSTANCE.block_till_done()
states = recorder.execute( db_states = recorder.query('States')
recorder.query('States')) states = recorder.execute(db_states)
assert db_states[0].event_id is not None
self.assertEqual(1, len(states)) self.assertEqual(1, len(states))
self.assertEqual(self.hass.states.get(entity_id), states[0]) self.assertEqual(self.hass.states.get(entity_id), states[0])

View file

@ -0,0 +1,140 @@
"""The tests for the Recorder component."""
# pylint: disable=too-many-public-methods,protected-access
import unittest
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
import homeassistant.core as ha
from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.util import dt
from homeassistant.components.recorder.models import (
Base, Events, States, RecorderRuns)
engine = None
Session = None
def setUpModule():
"""Set up a database to use."""
global engine, Session
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
def tearDownModule():
"""Close database."""
global engine, Session
engine.dispose()
engine = None
Session = None
class TestEvents(unittest.TestCase):
"""Test Events model."""
def test_from_event(self):
"""Test converting event to db event."""
event = ha.Event('test_event', {
'some_data': 15
})
assert event == Events.from_event(event).to_native()
class TestStates(unittest.TestCase):
"""Test States model."""
def test_from_event(self):
"""Test converting event to db state."""
state = ha.State('sensor.temperature', '18')
event = ha.Event(EVENT_STATE_CHANGED, {
'entity_id': 'sensor.temperature',
'old_state': None,
'new_state': state,
})
assert state == States.from_event(event).to_native()
def test_from_event_to_delete_state(self):
"""Test converting deleting state event to db state."""
event = ha.Event(EVENT_STATE_CHANGED, {
'entity_id': 'sensor.temperature',
'old_state': ha.State('sensor.temperature', '18'),
'new_state': None,
})
db_state = States.from_event(event)
assert db_state.entity_id == 'sensor.temperature'
assert db_state.domain == 'sensor'
assert db_state.state == ''
assert db_state.last_changed == event.time_fired
assert db_state.last_updated == event.time_fired
class TestRecorderRuns(unittest.TestCase):
"""Test recorder run model."""
def setUp(self):
"""Set up recorder runs."""
self.session = session = Session()
session.query(Events).delete()
session.query(States).delete()
session.query(RecorderRuns).delete()
def tearDown(self):
"""Clean up."""
self.session.rollback()
def test_entity_ids(self):
"""Test if entity ids helper method works."""
run = RecorderRuns(
start=datetime(2016, 7, 9, 11, 0, 0, tzinfo=dt.UTC),
end=datetime(2016, 7, 9, 23, 0, 0, tzinfo=dt.UTC),
closed_incorrect=False,
created=datetime(2016, 7, 9, 11, 0, 0, tzinfo=dt.UTC),
)
self.session.add(run)
self.session.commit()
before_run = datetime(2016, 7, 9, 8, 0, 0, tzinfo=dt.UTC)
in_run = datetime(2016, 7, 9, 13, 0, 0, tzinfo=dt.UTC)
in_run2 = datetime(2016, 7, 9, 15, 0, 0, tzinfo=dt.UTC)
in_run3 = datetime(2016, 7, 9, 18, 0, 0, tzinfo=dt.UTC)
after_run = datetime(2016, 7, 9, 23, 30, 0, tzinfo=dt.UTC)
assert run.to_native() == run
assert run.entity_ids() == []
self.session.add(States(
entity_id='sensor.temperature',
state='20',
last_changed=before_run,
last_updated=before_run,
))
self.session.add(States(
entity_id='sensor.sound',
state='10',
last_changed=after_run,
last_updated=after_run,
))
self.session.add(States(
entity_id='sensor.humidity',
state='76',
last_changed=in_run,
last_updated=in_run,
))
self.session.add(States(
entity_id='sensor.lux',
state='5',
last_changed=in_run3,
last_updated=in_run3,
))
assert sorted(run.entity_ids()) == ['sensor.humidity', 'sensor.lux']
assert run.entity_ids(in_run2) == ['sensor.humidity']