Fix invalid state (#16558)
* Fix invalid state * Make slightly more efficient in unsubscribing * Use uuid4"
This commit is contained in:
parent
629c4a0bf5
commit
06af76404f
3 changed files with 45 additions and 7 deletions
|
@ -225,6 +225,9 @@ ATTR_ID = 'id'
|
||||||
# Name
|
# Name
|
||||||
ATTR_NAME = 'name'
|
ATTR_NAME = 'name'
|
||||||
|
|
||||||
|
# Data for a SERVICE_EXECUTED event
|
||||||
|
ATTR_SERVICE_CALL_ID = 'service_call_id'
|
||||||
|
|
||||||
# Contains one string or a list of strings, each being an entity id
|
# Contains one string or a list of strings, each being an entity id
|
||||||
ATTR_ENTITY_ID = 'entity_id'
|
ATTR_ENTITY_ID = 'entity_id'
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
||||||
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||||
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||||
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
||||||
|
@ -1042,10 +1042,12 @@ class ServiceRegistry:
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
context = context or Context()
|
context = context or Context()
|
||||||
|
call_id = uuid.uuid4().hex
|
||||||
event_data = {
|
event_data = {
|
||||||
ATTR_DOMAIN: domain.lower(),
|
ATTR_DOMAIN: domain.lower(),
|
||||||
ATTR_SERVICE: service.lower(),
|
ATTR_SERVICE: service.lower(),
|
||||||
ATTR_SERVICE_DATA: service_data,
|
ATTR_SERVICE_DATA: service_data,
|
||||||
|
ATTR_SERVICE_CALL_ID: call_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if not blocking:
|
if not blocking:
|
||||||
|
@ -1058,8 +1060,9 @@ class ServiceRegistry:
|
||||||
@callback
|
@callback
|
||||||
def service_executed(event: Event) -> None:
|
def service_executed(event: Event) -> None:
|
||||||
"""Handle an executed service."""
|
"""Handle an executed service."""
|
||||||
if event.context == context:
|
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||||
fut.set_result(True)
|
fut.set_result(True)
|
||||||
|
unsub()
|
||||||
|
|
||||||
unsub = self._hass.bus.async_listen(
|
unsub = self._hass.bus.async_listen(
|
||||||
EVENT_SERVICE_EXECUTED, service_executed)
|
EVENT_SERVICE_EXECUTED, service_executed)
|
||||||
|
@ -1069,7 +1072,8 @@ class ServiceRegistry:
|
||||||
|
|
||||||
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
|
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
|
||||||
success = bool(done)
|
success = bool(done)
|
||||||
unsub()
|
if not success:
|
||||||
|
unsub()
|
||||||
return success
|
return success
|
||||||
|
|
||||||
async def _event_to_service_call(self, event: Event) -> None:
|
async def _event_to_service_call(self, event: Event) -> None:
|
||||||
|
@ -1077,6 +1081,7 @@ class ServiceRegistry:
|
||||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||||
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
||||||
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
||||||
|
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||||
|
|
||||||
if not self.has_service(domain, service):
|
if not self.has_service(domain, service):
|
||||||
if event.origin == EventOrigin.local:
|
if event.origin == EventOrigin.local:
|
||||||
|
@ -1088,12 +1093,17 @@ class ServiceRegistry:
|
||||||
|
|
||||||
def fire_service_executed() -> None:
|
def fire_service_executed() -> None:
|
||||||
"""Fire service executed event."""
|
"""Fire service executed event."""
|
||||||
|
if not call_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||||
|
|
||||||
if (service_handler.is_coroutinefunction or
|
if (service_handler.is_coroutinefunction or
|
||||||
service_handler.is_callback):
|
service_handler.is_callback):
|
||||||
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {},
|
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data,
|
||||||
EventOrigin.local, event.context)
|
EventOrigin.local, event.context)
|
||||||
else:
|
else:
|
||||||
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {},
|
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data,
|
||||||
EventOrigin.local, event.context)
|
EventOrigin.local, event.context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -20,9 +20,10 @@ from homeassistant.util.unit_system import (METRIC_SYSTEM)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
|
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
|
||||||
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
||||||
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
|
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED,
|
||||||
|
EVENT_SERVICE_EXECUTED)
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant
|
from tests.common import get_test_home_assistant, async_mock_service
|
||||||
|
|
||||||
PST = pytz.timezone('America/Los_Angeles')
|
PST = pytz.timezone('America/Los_Angeles')
|
||||||
|
|
||||||
|
@ -969,3 +970,27 @@ def test_track_task_functions(loop):
|
||||||
assert hass._track_task
|
assert hass._track_task
|
||||||
finally:
|
finally:
|
||||||
yield from hass.async_stop()
|
yield from hass.async_stop()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_service_executed_with_subservices(hass):
|
||||||
|
"""Test we block correctly till all services done."""
|
||||||
|
calls = async_mock_service(hass, 'test', 'inner')
|
||||||
|
|
||||||
|
async def handle_outer(call):
|
||||||
|
"""Handle outer service call."""
|
||||||
|
calls.append(call)
|
||||||
|
call1 = hass.services.async_call('test', 'inner', blocking=True,
|
||||||
|
context=call.context)
|
||||||
|
call2 = hass.services.async_call('test', 'inner', blocking=True,
|
||||||
|
context=call.context)
|
||||||
|
await asyncio.wait([call1, call2])
|
||||||
|
calls.append(call)
|
||||||
|
|
||||||
|
hass.services.async_register('test', 'outer', handle_outer)
|
||||||
|
|
||||||
|
await hass.services.async_call('test', 'outer', blocking=True)
|
||||||
|
|
||||||
|
assert len(calls) == 4
|
||||||
|
assert [call.service for call in calls] == [
|
||||||
|
'outer', 'inner', 'inner', 'outer']
|
||||||
|
assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0
|
||||||
|
|
Loading…
Add table
Reference in a new issue