Allow chaining contexts (#21028)
* Allow chaining contexts * Add stubbed out migration
This commit is contained in:
parent
b39846fb6b
commit
52f337ef00
12 changed files with 88 additions and 39 deletions
|
@ -7,7 +7,7 @@ import logging
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.setup import async_prepare_setup_platform
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
from homeassistant.core import CoreState
|
from homeassistant.core import CoreState, Context
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
|
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
|
||||||
|
@ -280,15 +280,21 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
if skip_condition or self._cond_func(variables):
|
if not skip_condition and not self._cond_func(variables):
|
||||||
self.async_set_context(context)
|
return
|
||||||
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
|
|
||||||
ATTR_NAME: self._name,
|
# Create a new context referring to the old context.
|
||||||
ATTR_ENTITY_ID: self.entity_id,
|
parent_id = None if context is None else context.id
|
||||||
}, context=context)
|
trigger_context = Context(parent_id=parent_id)
|
||||||
await self._async_action(self.entity_id, variables, context)
|
|
||||||
self._last_triggered = utcnow()
|
self.async_set_context(trigger_context)
|
||||||
await self.async_update_ha_state()
|
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
|
||||||
|
ATTR_NAME: self._name,
|
||||||
|
ATTR_ENTITY_ID: self.entity_id,
|
||||||
|
}, context=trigger_context)
|
||||||
|
await self._async_action(self.entity_id, variables, trigger_context)
|
||||||
|
self._last_triggered = utcnow()
|
||||||
|
await self.async_update_ha_state()
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self):
|
async def async_will_remove_from_hass(self):
|
||||||
"""Remove listeners when removing automation from HASS."""
|
"""Remove listeners when removing automation from HASS."""
|
||||||
|
|
|
@ -220,6 +220,15 @@ def _apply_update(engine, new_version, old_version):
|
||||||
_create_index(engine, "states", "ix_states_context_user_id")
|
_create_index(engine, "states", "ix_states_context_user_id")
|
||||||
elif new_version == 7:
|
elif new_version == 7:
|
||||||
_create_index(engine, "states", "ix_states_entity_id")
|
_create_index(engine, "states", "ix_states_entity_id")
|
||||||
|
elif new_version == 8:
|
||||||
|
# Pending migration, want to group a few.
|
||||||
|
pass
|
||||||
|
# _add_columns(engine, "events", [
|
||||||
|
# 'context_parent_id CHARACTER(36)',
|
||||||
|
# ])
|
||||||
|
# _add_columns(engine, "states", [
|
||||||
|
# 'context_parent_id CHARACTER(36)',
|
||||||
|
# ])
|
||||||
else:
|
else:
|
||||||
raise ValueError("No schema migration defined for version {}"
|
raise ValueError("No schema migration defined for version {}"
|
||||||
.format(new_version))
|
.format(new_version))
|
||||||
|
|
|
@ -34,16 +34,20 @@ class Events(Base): # type: ignore
|
||||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||||
context_id = Column(String(36), index=True)
|
context_id = Column(String(36), index=True)
|
||||||
context_user_id = Column(String(36), index=True)
|
context_user_id = Column(String(36), index=True)
|
||||||
|
# context_parent_id = Column(String(36), index=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_event(event):
|
def from_event(event):
|
||||||
"""Create an event database object from a native event."""
|
"""Create an event database object from a native event."""
|
||||||
return Events(event_type=event.event_type,
|
return Events(
|
||||||
event_data=json.dumps(event.data, cls=JSONEncoder),
|
event_type=event.event_type,
|
||||||
origin=str(event.origin),
|
event_data=json.dumps(event.data, cls=JSONEncoder),
|
||||||
time_fired=event.time_fired,
|
origin=str(event.origin),
|
||||||
context_id=event.context.id,
|
time_fired=event.time_fired,
|
||||||
context_user_id=event.context.user_id)
|
context_id=event.context.id,
|
||||||
|
context_user_id=event.context.user_id,
|
||||||
|
# context_parent_id=event.context.parent_id,
|
||||||
|
)
|
||||||
|
|
||||||
def to_native(self):
|
def to_native(self):
|
||||||
"""Convert to a natve HA Event."""
|
"""Convert to a natve HA Event."""
|
||||||
|
@ -81,6 +85,7 @@ class States(Base): # type: ignore
|
||||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||||
context_id = Column(String(36), index=True)
|
context_id = Column(String(36), index=True)
|
||||||
context_user_id = Column(String(36), index=True)
|
context_user_id = Column(String(36), index=True)
|
||||||
|
# context_parent_id = Column(String(36), index=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
# Used for fetching the state of entities at a specific time
|
# Used for fetching the state of entities at a specific time
|
||||||
|
@ -99,6 +104,7 @@ class States(Base): # type: ignore
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
context_id=event.context.id,
|
context_id=event.context.id,
|
||||||
context_user_id=event.context.user_id,
|
context_user_id=event.context.user_id,
|
||||||
|
# context_parent_id=event.context.parent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# State got deleted
|
# State got deleted
|
||||||
|
|
|
@ -409,6 +409,10 @@ class Context:
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parent_id = attr.ib(
|
||||||
|
type=Optional[str],
|
||||||
|
default=None
|
||||||
|
)
|
||||||
id = attr.ib(
|
id = attr.ib(
|
||||||
type=str,
|
type=str,
|
||||||
default=attr.Factory(lambda: uuid.uuid4().hex),
|
default=attr.Factory(lambda: uuid.uuid4().hex),
|
||||||
|
@ -418,6 +422,7 @@ class Context:
|
||||||
"""Return a dictionary representation of the context."""
|
"""Return a dictionary representation of the context."""
|
||||||
return {
|
return {
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
|
'parent_id': self.parent_id,
|
||||||
'user_id': self.user_id,
|
'user_id': self.user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ async def test_if_fires_on_event(hass, calls):
|
||||||
hass.bus.async_fire('test_event', context=context)
|
hass.bus.async_fire('test_event', context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
|
|
||||||
await common.async_turn_off(hass)
|
await common.async_turn_off(hass)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
|
@ -68,7 +68,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
assert 'geo_location - geo_location.entity - hello - hello - test' == \
|
assert 'geo_location - geo_location.entity - hello - hello - test' == \
|
||||||
calls[0].data['some']
|
calls[0].data['some']
|
||||||
|
|
||||||
|
@ -221,7 +221,7 @@ async def test_if_fires_on_zone_appear(hass, calls):
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
assert 'geo_location - geo_location.entity - - hello - test' == \
|
assert 'geo_location - geo_location.entity - - hello - test' == \
|
||||||
calls[0].data['some']
|
calls[0].data['some']
|
||||||
|
|
||||||
|
|
|
@ -369,38 +369,47 @@ async def test_shared_context(hass, calls):
|
||||||
})
|
})
|
||||||
|
|
||||||
context = Context()
|
context = Context()
|
||||||
automation_mock = Mock()
|
first_automation_listener = Mock()
|
||||||
event_mock = Mock()
|
event_mock = Mock()
|
||||||
|
|
||||||
hass.bus.async_listen('test_event2', automation_mock)
|
hass.bus.async_listen('test_event2', first_automation_listener)
|
||||||
hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock)
|
hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock)
|
||||||
hass.bus.async_fire('test_event', context=context)
|
hass.bus.async_fire('test_event', context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
# Ensure events was fired
|
# Ensure events was fired
|
||||||
assert automation_mock.call_count == 1
|
assert first_automation_listener.call_count == 1
|
||||||
assert event_mock.call_count == 2
|
assert event_mock.call_count == 2
|
||||||
|
|
||||||
# Ensure context carries through the event
|
# Verify automation triggered evenet for 'hello' automation
|
||||||
args, kwargs = automation_mock.call_args
|
args, kwargs = event_mock.call_args_list[0]
|
||||||
assert args[0].context == context
|
first_trigger_context = args[0].context
|
||||||
|
assert first_trigger_context.parent_id == context.id
|
||||||
|
# Ensure event data has all attributes set
|
||||||
|
assert args[0].data.get(ATTR_NAME) is not None
|
||||||
|
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
||||||
|
|
||||||
for call in event_mock.call_args_list:
|
# Ensure context set correctly for event fired by 'hello' automation
|
||||||
args, kwargs = call
|
args, kwargs = first_automation_listener.call_args
|
||||||
assert args[0].context == context
|
assert args[0].context is first_trigger_context
|
||||||
# Ensure event data has all attributes set
|
|
||||||
assert args[0].data.get(ATTR_NAME) is not None
|
|
||||||
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
|
||||||
|
|
||||||
# Ensure the automation state shares the same context
|
# Ensure the 'hello' automation state has the right context
|
||||||
state = hass.states.get('automation.hello')
|
state = hass.states.get('automation.hello')
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.context == context
|
assert state.context is first_trigger_context
|
||||||
|
|
||||||
|
# Verify automation triggered evenet for 'bye' automation
|
||||||
|
args, kwargs = event_mock.call_args_list[1]
|
||||||
|
second_trigger_context = args[0].context
|
||||||
|
assert second_trigger_context.parent_id == first_trigger_context.id
|
||||||
|
# Ensure event data has all attributes set
|
||||||
|
assert args[0].data.get(ATTR_NAME) is not None
|
||||||
|
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
||||||
|
|
||||||
# Ensure the service call from the second automation
|
# Ensure the service call from the second automation
|
||||||
# shares the same context
|
# shares the same context
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert calls[0].context == context
|
assert calls[0].context is second_trigger_context
|
||||||
|
|
||||||
|
|
||||||
async def test_services(hass, calls):
|
async def test_services(hass, calls):
|
||||||
|
|
|
@ -45,7 +45,7 @@ async def test_if_fires_on_entity_change_below(hass, calls):
|
||||||
hass.states.async_set('test.entity', 9, context=context)
|
hass.states.async_set('test.entity', 9, context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
|
|
||||||
# Set above 12 so the automation will fire again
|
# Set above 12 so the automation will fire again
|
||||||
hass.states.async_set('test.entity', 12)
|
hass.states.async_set('test.entity', 12)
|
||||||
|
@ -134,7 +134,7 @@ async def test_if_not_fires_on_entity_change_below_to_below(hass, calls):
|
||||||
hass.states.async_set('test.entity', 9, context=context)
|
hass.states.async_set('test.entity', 9, context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
|
|
||||||
# already below so should not fire again
|
# already below so should not fire again
|
||||||
hass.states.async_set('test.entity', 5)
|
hass.states.async_set('test.entity', 5)
|
||||||
|
|
|
@ -55,7 +55,7 @@ async def test_if_fires_on_entity_change(hass, calls):
|
||||||
hass.states.async_set('test.entity', 'world', context=context)
|
hass.states.async_set('test.entity', 'world', context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
assert 'state - test.entity - hello - world - None' == \
|
assert 'state - test.entity - hello - world - None' == \
|
||||||
calls[0].data['some']
|
calls[0].data['some']
|
||||||
|
|
||||||
|
|
|
@ -257,7 +257,7 @@ async def test_if_fires_on_change_with_template_advanced(hass, calls):
|
||||||
hass.states.async_set('test.entity', 'world', context=context)
|
hass.states.async_set('test.entity', 'world', context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
assert 'template - test.entity - hello - world' == \
|
assert 'template - test.entity - hello - world' == \
|
||||||
calls[0].data['some']
|
calls[0].data['some']
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert 1 == len(calls)
|
assert 1 == len(calls)
|
||||||
assert calls[0].context is context
|
assert calls[0].context.parent_id == context.id
|
||||||
assert 'zone - test.entity - hello - hello - test' == \
|
assert 'zone - test.entity - hello - hello - test' == \
|
||||||
calls[0].data['some']
|
calls[0].data['some']
|
||||||
|
|
||||||
|
|
|
@ -310,6 +310,7 @@ class TestEvent(unittest.TestCase):
|
||||||
'time_fired': now,
|
'time_fired': now,
|
||||||
'context': {
|
'context': {
|
||||||
'id': event.context.id,
|
'id': event.context.id,
|
||||||
|
'parent_id': None,
|
||||||
'user_id': event.context.user_id,
|
'user_id': event.context.user_id,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -1076,3 +1077,16 @@ async def test_service_call_event_contains_original_data(hass):
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert calls[0].data['number'] == 23
|
assert calls[0].data['number'] == 23
|
||||||
assert calls[0].context is context
|
assert calls[0].context is context
|
||||||
|
|
||||||
|
|
||||||
|
def test_context():
|
||||||
|
"""Test context init."""
|
||||||
|
c = ha.Context()
|
||||||
|
assert c.user_id is None
|
||||||
|
assert c.parent_id is None
|
||||||
|
assert c.id is not None
|
||||||
|
|
||||||
|
c = ha.Context(23, 100)
|
||||||
|
assert c.user_id == 23
|
||||||
|
assert c.parent_id == 100
|
||||||
|
assert c.id is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue