Add and restore context in recorder (#15859)
This commit is contained in:
parent
da916d7b27
commit
9512bb9587
7 changed files with 73 additions and 14 deletions
|
@ -114,6 +114,27 @@ def _drop_index(engine, table_name, index_name):
|
|||
"critical operation.", index_name, table_name)
|
||||
|
||||
|
||||
def _add_columns(engine, table_name, columns_def):
    """Add columns to a table."""
    from sqlalchemy import text
    from sqlalchemy.exc import SQLAlchemyError

    # Turn each bare column definition into an ALTER TABLE sub-clause.
    add_clauses = [
        'ADD COLUMN {}'.format(col_def) for col_def in columns_def
    ]

    # Fast path: add every column with a single ALTER TABLE statement.
    # NOTE(review): presumably some backends reject multiple ADD COLUMN
    # clauses in one statement, which is why the fallback below exists.
    try:
        engine.execute(text("ALTER TABLE {table} {columns_def}".format(
            table=table_name,
            columns_def=', '.join(add_clauses))))
        return
    except SQLAlchemyError:
        pass

    # Fallback: issue one ALTER TABLE per column.
    for clause in add_clauses:
        engine.execute(text("ALTER TABLE {table} {column_def}".format(
            table=table_name,
            column_def=clause)))
|
||||
|
||||
|
||||
def _apply_update(engine, new_version, old_version):
|
||||
"""Perform operations to bring schema up to date."""
|
||||
if new_version == 1:
|
||||
|
@ -146,6 +167,19 @@ def _apply_update(engine, new_version, old_version):
|
|||
elif new_version == 5:
|
||||
# Create supporting index for States.event_id foreign key
|
||||
_create_index(engine, "states", "ix_states_event_id")
|
||||
elif new_version == 6:
|
||||
_add_columns(engine, "events", [
|
||||
'context_id CHARACTER(36)',
|
||||
'context_user_id CHARACTER(36)',
|
||||
])
|
||||
_create_index(engine, "events", "ix_events_context_id")
|
||||
_create_index(engine, "events", "ix_events_context_user_id")
|
||||
_add_columns(engine, "states", [
|
||||
'context_id CHARACTER(36)',
|
||||
'context_user_id CHARACTER(36)',
|
||||
])
|
||||
_create_index(engine, "states", "ix_states_context_id")
|
||||
_create_index(engine, "states", "ix_states_context_user_id")
|
||||
else:
|
||||
raise ValueError("No schema migration defined for version {}"
|
||||
.format(new_version))
|
||||
|
|
|
@ -9,14 +9,15 @@ from sqlalchemy import (
|
|||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.core import Event, EventOrigin, State, split_entity_id
|
||||
from homeassistant.core import (
|
||||
Context, Event, EventOrigin, State, split_entity_id)
|
||||
from homeassistant.remote import JSONEncoder
|
||||
|
||||
# SQLAlchemy Schema
|
||||
# pylint: disable=invalid-name
|
||||
Base = declarative_base()
|
||||
|
||||
SCHEMA_VERSION = 5
|
||||
SCHEMA_VERSION = 6
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -31,6 +32,8 @@ class Events(Base): # type: ignore
|
|||
origin = Column(String(32))
|
||||
time_fired = Column(DateTime(timezone=True), index=True)
|
||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
context_id = Column(String(36), index=True)
|
||||
context_user_id = Column(String(36), index=True)
|
||||
|
||||
@staticmethod
|
||||
def from_event(event):
|
||||
|
@ -38,16 +41,23 @@ class Events(Base): # type: ignore
|
|||
return Events(event_type=event.event_type,
|
||||
event_data=json.dumps(event.data, cls=JSONEncoder),
|
||||
origin=str(event.origin),
|
||||
time_fired=event.time_fired)
|
||||
time_fired=event.time_fired,
|
||||
context_id=event.context.id,
|
||||
context_user_id=event.context.user_id)
|
||||
|
||||
def to_native(self):
|
||||
"""Convert to a natve HA Event."""
|
||||
context = Context(
|
||||
id=self.context_id,
|
||||
user_id=self.context_user_id
|
||||
)
|
||||
try:
|
||||
return Event(
|
||||
self.event_type,
|
||||
json.loads(self.event_data),
|
||||
EventOrigin(self.origin),
|
||||
_process_timestamp(self.time_fired)
|
||||
_process_timestamp(self.time_fired),
|
||||
context=context,
|
||||
)
|
||||
except ValueError:
|
||||
# When json.loads fails
|
||||
|
@ -69,6 +79,8 @@ class States(Base): # type: ignore
|
|||
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow,
|
||||
index=True)
|
||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
context_id = Column(String(36), index=True)
|
||||
context_user_id = Column(String(36), index=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Used for fetching the state of entities at a specific time
|
||||
|
@ -82,7 +94,11 @@ class States(Base): # type: ignore
|
|||
entity_id = event.data['entity_id']
|
||||
state = event.data.get('new_state')
|
||||
|
||||
dbstate = States(entity_id=entity_id)
|
||||
dbstate = States(
|
||||
entity_id=entity_id,
|
||||
context_id=event.context.id,
|
||||
context_user_id=event.context.user_id,
|
||||
)
|
||||
|
||||
# State got deleted
|
||||
if state is None:
|
||||
|
@ -103,12 +119,17 @@ class States(Base): # type: ignore
|
|||
|
||||
def to_native(self):
|
||||
"""Convert to an HA state object."""
|
||||
context = Context(
|
||||
id=self.context_id,
|
||||
user_id=self.context_user_id
|
||||
)
|
||||
try:
|
||||
return State(
|
||||
self.entity_id, self.state,
|
||||
json.loads(self.attributes),
|
||||
_process_timestamp(self.last_changed),
|
||||
_process_timestamp(self.last_updated)
|
||||
_process_timestamp(self.last_updated),
|
||||
context=context,
|
||||
)
|
||||
except ValueError:
|
||||
# When json.loads fails
|
||||
|
|
|
@ -423,7 +423,8 @@ class Event:
|
|||
self.event_type == other.event_type and
|
||||
self.data == other.data and
|
||||
self.origin == other.origin and
|
||||
self.time_fired == other.time_fired)
|
||||
self.time_fired == other.time_fired and
|
||||
self.context == other.context)
|
||||
|
||||
|
||||
class EventBus:
|
||||
|
@ -695,7 +696,8 @@ class State:
|
|||
return (self.__class__ == other.__class__ and # type: ignore
|
||||
self.entity_id == other.entity_id and
|
||||
self.state == other.state and
|
||||
self.attributes == other.attributes)
|
||||
self.attributes == other.attributes and
|
||||
self.context == other.context)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the states."""
|
||||
|
|
|
@ -266,7 +266,7 @@ def mock_state_change_event(hass, new_state, old_state=None):
|
|||
if old_state:
|
||||
event_data['old_state'] = old_state
|
||||
|
||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
|
||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
|
|
@ -60,7 +60,7 @@ class TestStates(unittest.TestCase):
|
|||
'entity_id': 'sensor.temperature',
|
||||
'old_state': None,
|
||||
'new_state': state,
|
||||
})
|
||||
}, context=state.context)
|
||||
assert state == States.from_event(event).to_native()
|
||||
|
||||
def test_from_event_to_delete_state(self):
|
||||
|
|
|
@ -83,9 +83,10 @@ class TestComponentHistory(unittest.TestCase):
|
|||
self.wait_recording_done()
|
||||
|
||||
# Get states returns everything before POINT
|
||||
self.assertEqual(states,
|
||||
sorted(history.get_states(self.hass, future),
|
||||
key=lambda state: state.entity_id))
|
||||
for state1, state2 in zip(
|
||||
states, sorted(history.get_states(self.hass, future),
|
||||
key=lambda state: state.entity_id)):
|
||||
assert state1 == state2
|
||||
|
||||
# Test get_state here because we have a DB setup
|
||||
self.assertEqual(
|
||||
|
|
|
@ -246,8 +246,9 @@ class TestEvent(unittest.TestCase):
|
|||
"""Test events."""
|
||||
now = dt_util.utcnow()
|
||||
data = {'some': 'attr'}
|
||||
context = ha.Context()
|
||||
event1, event2 = [
|
||||
ha.Event('some_type', data, time_fired=now)
|
||||
ha.Event('some_type', data, time_fired=now, context=context)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue