From 9512bb95872a76d02df84fe5c40fcc9461f42878 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 10 Aug 2018 18:09:01 +0200 Subject: [PATCH] Add and restore context in recorder (#15859) --- .../components/recorder/migration.py | 34 +++++++++++++++++++ homeassistant/components/recorder/models.py | 33 ++++++++++++++---- homeassistant/core.py | 6 ++-- tests/common.py | 2 +- tests/components/recorder/test_models.py | 2 +- tests/components/test_history.py | 7 ++-- tests/test_core.py | 3 +- 7 files changed, 73 insertions(+), 14 deletions(-) diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index af70c9d998c..939985ebfb1 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -114,6 +114,27 @@ def _drop_index(engine, table_name, index_name): "critical operation.", index_name, table_name) +def _add_columns(engine, table_name, columns_def): + """Add columns to a table.""" + from sqlalchemy import text + from sqlalchemy.exc import SQLAlchemyError + + columns_def = ['ADD COLUMN {}'.format(col_def) for col_def in columns_def] + + try: + engine.execute(text("ALTER TABLE {table} {columns_def}".format( + table=table_name, + columns_def=', '.join(columns_def)))) + return + except SQLAlchemyError: + pass + + for column_def in columns_def: + engine.execute(text("ALTER TABLE {table} {column_def}".format( + table=table_name, + column_def=column_def))) + + def _apply_update(engine, new_version, old_version): """Perform operations to bring schema up to date.""" if new_version == 1: @@ -146,6 +167,19 @@ def _apply_update(engine, new_version, old_version): elif new_version == 5: # Create supporting index for States.event_id foreign key _create_index(engine, "states", "ix_states_event_id") + elif new_version == 6: + _add_columns(engine, "events", [ + 'context_id CHARACTER(36)', + 'context_user_id CHARACTER(36)', + ]) + _create_index(engine, "events", "ix_events_context_id") + _create_index(engine, "events", "ix_events_context_user_id") + _add_columns(engine, "states", [ + 'context_id CHARACTER(36)', + 'context_user_id CHARACTER(36)', + ]) + _create_index(engine, "states", "ix_states_context_id") + _create_index(engine, "states", "ix_states_context_user_id") else: raise ValueError("No schema migration defined for version {}" .format(new_version)) diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index e7948446231..b8b777990f7 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -9,14 +9,15 @@ from sqlalchemy import ( from sqlalchemy.ext.declarative import declarative_base import homeassistant.util.dt as dt_util -from homeassistant.core import Event, EventOrigin, State, split_entity_id +from homeassistant.core import ( + Context, Event, EventOrigin, State, split_entity_id) from homeassistant.remote import JSONEncoder # SQLAlchemy Schema # pylint: disable=invalid-name Base = declarative_base() -SCHEMA_VERSION = 5 +SCHEMA_VERSION = 6 _LOGGER = logging.getLogger(__name__) @@ -31,6 +32,8 @@ class Events(Base): # type: ignore origin = Column(String(32)) time_fired = Column(DateTime(timezone=True), index=True) created = Column(DateTime(timezone=True), default=datetime.utcnow) + context_id = Column(String(36), index=True) + context_user_id = Column(String(36), index=True) @staticmethod def from_event(event): @@ -38,16 +41,23 @@ class Events(Base): # type: ignore return Events(event_type=event.event_type, event_data=json.dumps(event.data, cls=JSONEncoder), origin=str(event.origin), - time_fired=event.time_fired) + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id) def to_native(self): """Convert to a natve HA Event.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id + ) try: return Event( self.event_type, json.loads(self.event_data), EventOrigin(self.origin), - _process_timestamp(self.time_fired) + _process_timestamp(self.time_fired), + context=context, ) except ValueError: # When json.loads fails @@ -69,6 +79,8 @@ class States(Base): # type: ignore last_updated = Column(DateTime(timezone=True), default=datetime.utcnow, index=True) created = Column(DateTime(timezone=True), default=datetime.utcnow) + context_id = Column(String(36), index=True) + context_user_id = Column(String(36), index=True) __table_args__ = ( # Used for fetching the state of entities at a specific time @@ -82,7 +94,11 @@ class States(Base): # type: ignore entity_id = event.data['entity_id'] state = event.data.get('new_state') - dbstate = States(entity_id=entity_id) + dbstate = States( + entity_id=entity_id, + context_id=event.context.id, + context_user_id=event.context.user_id, + ) # State got deleted if state is None: @@ -103,12 +119,17 @@ class States(Base): # type: ignore def to_native(self): """Convert to an HA state object.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id + ) try: return State( self.entity_id, self.state, json.loads(self.attributes), _process_timestamp(self.last_changed), - _process_timestamp(self.last_updated) + _process_timestamp(self.last_updated), + context=context, ) except ValueError: # When json.loads fails diff --git a/homeassistant/core.py b/homeassistant/core.py index b17df2c11fe..cc027c6f5d0 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -423,7 +423,8 @@ class Event: self.event_type == other.event_type and self.data == other.data and self.origin == other.origin and - self.time_fired == other.time_fired) + self.time_fired == other.time_fired and + self.context == other.context) class EventBus: @@ -695,7 +696,8 @@ class State: return (self.__class__ == other.__class__ and # type: ignore self.entity_id == other.entity_id and self.state == other.state and - self.attributes == other.attributes) + self.attributes == other.attributes and + self.context == other.context) def __repr__(self) -> str: """Return the representation of the states.""" diff --git a/tests/common.py b/tests/common.py index 3a2248d0d50..df333cca735 100644 --- a/tests/common.py +++ b/tests/common.py @@ -266,7 +266,7 @@ def mock_state_change_event(hass, new_state, old_state=None): if old_state: event_data['old_state'] = old_state - hass.bus.fire(EVENT_STATE_CHANGED, event_data) + hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context) @asyncio.coroutine diff --git a/tests/components/recorder/test_models.py b/tests/components/recorder/test_models.py index c616f3d0af1..3d1beb3a642 100644 --- a/tests/components/recorder/test_models.py +++ b/tests/components/recorder/test_models.py @@ -60,7 +60,7 @@ class TestStates(unittest.TestCase): 'entity_id': 'sensor.temperature', 'old_state': None, 'new_state': state, - }) + }, context=state.context) assert state == States.from_event(event).to_native() def test_from_event_to_delete_state(self): diff --git a/tests/components/test_history.py b/tests/components/test_history.py index 70f7152e07f..b348498b07e 100644 --- a/tests/components/test_history.py +++ b/tests/components/test_history.py @@ -83,9 +83,10 @@ class TestComponentHistory(unittest.TestCase): self.wait_recording_done() # Get states returns everything before POINT - self.assertEqual(states, - sorted(history.get_states(self.hass, future), - key=lambda state: state.entity_id)) + for state1, state2 in zip( + states, sorted(history.get_states(self.hass, future), + key=lambda state: state.entity_id)): + assert state1 == state2 # Test get_state here because we have a DB setup self.assertEqual( diff --git a/tests/test_core.py b/tests/test_core.py index 9de801e0bb4..f23bed6bc8a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -246,8 +246,9 @@ class TestEvent(unittest.TestCase): """Test events.""" now = dt_util.utcnow() data = {'some': 'attr'} + context = ha.Context() event1, event2 = [ - ha.Event('some_type', data, time_fired=now) + ha.Event('some_type', data, time_fired=now, context=context) for _ in range(2) ]