From b7722ec4523c3144c5bf98cc9d13311448c2c74c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 30 Jan 2016 15:16:31 -0800 Subject: [PATCH] Allow usage of words domain, service, call_id in service data --- homeassistant/components/mqtt_eventstream.py | 3 +- homeassistant/const.py | 1 + homeassistant/core.py | 32 +++++++++++--------- tests/components/test_mqtt.py | 6 ++-- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/homeassistant/components/mqtt_eventstream.py b/homeassistant/components/mqtt_eventstream.py index 53573205378..e69639572ca 100644 --- a/homeassistant/components/mqtt_eventstream.py +++ b/homeassistant/components/mqtt_eventstream.py @@ -11,6 +11,7 @@ from homeassistant.core import EventOrigin, State from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN from homeassistant.components.mqtt import SERVICE_PUBLISH as MQTT_SVC_PUBLISH from homeassistant.const import ( + ATTR_SERVICE_DATA, MATCH_ALL, EVENT_TIME_CHANGED, EVENT_CALL_SERVICE, @@ -46,7 +47,7 @@ def setup(hass, config): if ( event.data.get('domain') == MQTT_DOMAIN and event.data.get('service') == MQTT_SVC_PUBLISH and - event.data.get('topic') == pub_topic + event.data[ATTR_SERVICE_DATA].get('topic') == pub_topic ): return diff --git a/homeassistant/const.py b/homeassistant/const.py index 4109b32a263..e28c418d9e4 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -67,6 +67,7 @@ ATTR_NOW = "now" # Contains domain, service for a SERVICE_CALL event ATTR_DOMAIN = "domain" ATTR_SERVICE = "service" +ATTR_SERVICE_DATA = "service_data" # Data for a SERVICE_EXECUTED event ATTR_SERVICE_CALL_ID = "service_call_id" diff --git a/homeassistant/core.py b/homeassistant/core.py index 853d09020ce..6f95cedb9a9 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -19,7 +19,7 @@ from homeassistant.const import ( SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED, EVENT_CALL_SERVICE, ATTR_NOW, ATTR_DOMAIN, ATTR_SERVICE, MATCH_ALL, EVENT_SERVICE_EXECUTED, ATTR_SERVICE_CALL_ID, EVENT_SERVICE_REGISTERED, - TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME) + TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME, ATTR_SERVICE_DATA) from homeassistant.exceptions import ( HomeAssistantError, InvalidEntityFormatError) import homeassistant.util as util @@ -555,13 +555,14 @@ class Service(object): class ServiceCall(object): """Represents a call to a service.""" - __slots__ = ['domain', 'service', 'data'] + __slots__ = ['domain', 'service', 'data', 'call_id'] - def __init__(self, domain, service, data=None): + def __init__(self, domain, service, data=None, call_id=None): """Initialize a service call.""" self.domain = domain self.service = service self.data = data or {} + self.call_id = call_id def __repr__(self): if self.data: @@ -633,10 +634,13 @@ class ServiceRegistry(object): the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. """ call_id = self._generate_unique_id() - event_data = service_data or {} - event_data[ATTR_DOMAIN] = domain - event_data[ATTR_SERVICE] = service - event_data[ATTR_SERVICE_CALL_ID] = call_id + + event_data = { + ATTR_DOMAIN: domain, + ATTR_SERVICE: service, + ATTR_SERVICE_DATA: service_data, + ATTR_SERVICE_CALL_ID: call_id, + } if blocking: executed_event = threading.Event() @@ -658,15 +662,16 @@ class ServiceRegistry(object): def _event_to_service_call(self, event): """Callback for SERVICE_CALLED events from the event bus.""" - service_data = dict(event.data) - domain = service_data.pop(ATTR_DOMAIN, None) - service = service_data.pop(ATTR_SERVICE, None) + service_data = event.data.get(ATTR_SERVICE_DATA) + domain = event.data.get(ATTR_DOMAIN) + service = event.data.get(ATTR_SERVICE) + call_id = event.data.get(ATTR_SERVICE_CALL_ID) if not self.has_service(domain, service): return service_handler = self._services[domain][service] - service_call = ServiceCall(domain, service, service_data) + service_call = ServiceCall(domain, service, service_data, call_id) # Add a job to the pool that calls _execute_service self._pool.add_job(JobPriority.EVENT_SERVICE, @@ -678,10 +683,9 @@ class ServiceRegistry(object): service, call = service_and_call service(call) - if ATTR_SERVICE_CALL_ID in call.data: + if call.call_id is not None: self._bus.fire( - EVENT_SERVICE_EXECUTED, - {ATTR_SERVICE_CALL_ID: call.data[ATTR_SERVICE_CALL_ID]}) + EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id}) def _generate_unique_id(self): """Generate a unique service call id.""" diff --git a/tests/components/test_mqtt.py b/tests/components/test_mqtt.py index 40e473a3572..c36459e5500 100644 --- a/tests/components/test_mqtt.py +++ b/tests/components/test_mqtt.py @@ -63,8 +63,10 @@ class TestMQTT(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) - self.assertEqual('test-topic', self.calls[0][0].data[mqtt.ATTR_TOPIC]) - self.assertEqual('test-payload', self.calls[0][0].data[mqtt.ATTR_PAYLOAD]) + self.assertEqual('test-topic', + self.calls[0][0].data['service_data'][mqtt.ATTR_TOPIC]) + self.assertEqual('test-payload', + self.calls[0][0].data['service_data'][mqtt.ATTR_PAYLOAD]) def test_service_call_without_topic_does_not_publush(self): self.hass.bus.fire(EVENT_CALL_SERVICE, {