From 5957e4b75b7e629ef22952b1ffd47985c731bb5e Mon Sep 17 00:00:00 2001 From: emontnemery Date: Wed, 13 Mar 2019 20:58:20 +0100 Subject: [PATCH] Pass Message object to MQTT message callbacks (#21959) * Pass Message object to MQTT message callbacks * Improve method of detecting deprecated msg callback * Fix mysensors * Fixup * Review comments * Fix merge error --- homeassistant/components/mqtt/__init__.py | 99 +++++++++++++------ .../components/mqtt/alarm_control_panel.py | 17 ++-- .../components/mqtt/binary_sensor.py | 3 +- homeassistant/components/mqtt/camera.py | 4 +- homeassistant/components/mqtt/climate.py | 24 +++-- homeassistant/components/mqtt/cover.py | 14 +-- .../components/mqtt/device_tracker.py | 4 +- homeassistant/components/mqtt/discovery.py | 4 +- homeassistant/components/mqtt/fan.py | 12 +-- .../components/mqtt/light/schema_basic.py | 51 +++++----- .../components/mqtt/light/schema_json.py | 4 +- .../components/mqtt/light/schema_template.py | 18 ++-- homeassistant/components/mqtt/lock.py | 3 +- homeassistant/components/mqtt/sensor.py | 3 +- homeassistant/components/mqtt/switch.py | 3 +- homeassistant/components/mqtt/vacuum.py | 26 ++--- homeassistant/components/mysensors/gateway.py | 4 +- tests/components/mqtt/test_init.py | 40 ++++---- tests/components/mqtt/test_subscription.py | 12 +-- 19 files changed, 203 insertions(+), 142 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index ed671a2f8ce..e4d468e2155 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -5,6 +5,8 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ import asyncio +import inspect +from functools import partial, wraps from itertools import groupby import json import logging @@ -264,7 +266,19 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({ # pylint: disable=invalid-name PublishPayloadType = Union[str, bytes, int, float, None] SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None -MessageCallbackType = Callable[[str, SubscribePayloadType, int], None] + + +@attr.s(slots=True, frozen=True) +class Message: + """MQTT Message.""" + + topic = attr.ib(type=str) + payload = attr.ib(type=PublishPayloadType) + qos = attr.ib(type=int) + retain = attr.ib(type=bool) + + +MessageCallbackType = Callable[[Message], None] def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType: @@ -304,6 +318,30 @@ def publish_template(hass: HomeAssistantType, topic, payload_template, hass.services.call(DOMAIN, SERVICE_PUBLISH, data) +def wrap_msg_callback( + msg_callback: MessageCallbackType) -> MessageCallbackType: + """Wrap an MQTT message callback to support deprecated signature.""" + # Check for partials to properly determine if coroutine function + check_func = msg_callback + while isinstance(check_func, partial): + check_func = check_func.func + + wrapper_func = None + if asyncio.iscoroutinefunction(check_func): + @wraps(msg_callback) + async def async_wrapper(msg: Any) -> None: + """Catch and log exception.""" + await msg_callback(msg.topic, msg.payload, msg.qos) + wrapper_func = async_wrapper + else: + @wraps(msg_callback) + def wrapper(msg: Any) -> None: + """Catch and log exception.""" + msg_callback(msg.topic, msg.payload, msg.qos) + wrapper_func = wrapper + return wrapper_func + + @bind_hass async def async_subscribe(hass: HomeAssistantType, topic: str, msg_callback: MessageCallbackType, @@ -313,11 +351,25 @@ async def async_subscribe(hass: HomeAssistantType, topic: str, Call the return value to unsubscribe. """ + # Count callback parameters which don't have a default value + non_default = 0 + if msg_callback: + non_default = sum(p.default == inspect.Parameter.empty for _, p in + inspect.signature(msg_callback).parameters.items()) + + wrapped_msg_callback = msg_callback + # If we have 3 paramaters with no default value, wrap the callback + if non_default == 3: + _LOGGER.info( + "Signature of MQTT msg_callback '%s.%s' is deprecated", + inspect.getmodule(msg_callback).__name__, msg_callback.__name__) + wrapped_msg_callback = wrap_msg_callback(msg_callback) + async_remove = await hass.data[DATA_MQTT].async_subscribe( topic, catch_log_exception( - msg_callback, lambda topic, msg, qos: + wrapped_msg_callback, lambda msg: "Exception in {} when handling msg on '{}': '{}'".format( - msg_callback.__name__, topic, msg)), + msg_callback.__name__, msg.topic, msg.payload)), qos, encoding) return async_remove @@ -575,16 +627,6 @@ class Subscription: encoding = attr.ib(type=str, default='utf-8') -@attr.s(slots=True, frozen=True) -class Message: - """MQTT Message.""" - - topic = attr.ib(type=str) - payload = attr.ib(type=PublishPayloadType) - qos = attr.ib(type=int, default=0) - retain = attr.ib(type=bool, default=False) - - class MQTT: """Home Assistant MQTT client.""" @@ -770,7 +812,8 @@ class MQTT: @callback def _mqtt_handle_message(self, msg) -> None: - _LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload) + _LOGGER.debug("Received message on %s%s: %s", msg.topic, + " (retained)" if msg.retain else "", msg.payload) for subscription in self.subscriptions: if not _match_topic(subscription.topic, msg.topic): @@ -787,7 +830,8 @@ class MQTT: continue self.hass.async_run_job( - subscription.callback, msg.topic, payload, msg.qos) + subscription.callback, Message(msg.topic, payload, msg.qos, + msg.retain)) def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: """Disconnected callback.""" @@ -865,11 +909,9 @@ class MqttAttributes(Entity): from .subscription import async_subscribe_topics @callback - def attributes_message_received(topic: str, - payload: SubscribePayloadType, - qos: int) -> None: + def attributes_message_received(msg: Message) -> None: try: - json_dict = json.loads(payload) + json_dict = json.loads(msg.payload) if isinstance(json_dict, dict): self._attributes = json_dict self.async_write_ha_state() @@ -877,7 +919,7 @@ class MqttAttributes(Entity): _LOGGER.warning("JSON result was not a dictionary") self._attributes = None except ValueError: - _LOGGER.warning("Erroneous JSON: %s", payload) + _LOGGER.warning("Erroneous JSON: %s", msg.payload) self._attributes = None self._attributes_sub_state = await async_subscribe_topics( @@ -927,13 +969,11 @@ class MqttAvailability(Entity): from .subscription import async_subscribe_topics @callback - def availability_message_received(topic: str, - payload: SubscribePayloadType, - qos: int) -> None: + def availability_message_received(msg: Message) -> None: """Handle a new received MQTT availability message.""" - if payload == self._avail_config[CONF_PAYLOAD_AVAILABLE]: + if msg.payload == self._avail_config[CONF_PAYLOAD_AVAILABLE]: self._available = True - elif payload == self._avail_config[CONF_PAYLOAD_NOT_AVAILABLE]: + elif msg.payload == self._avail_config[CONF_PAYLOAD_NOT_AVAILABLE]: self._available = False self.async_write_ha_state() @@ -1064,12 +1104,13 @@ async def websocket_subscribe(hass, connection, msg): if not connection.user.is_admin: raise Unauthorized - async def forward_messages(topic: str, payload: str, qos: int): + async def forward_messages(mqttmsg: Message): """Forward events to websocket.""" connection.send_message(websocket_api.event_message(msg['id'], { - 'topic': topic, - 'payload': payload, - 'qos': qos, + 'topic': mqttmsg.topic, + 'payload': mqttmsg.payload, + 'qos': mqttmsg.qos, + 'retain': mqttmsg.retain, })) connection.subscriptions[msg['id']] = await async_subscribe( diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index a03716676cd..c350b32b4ff 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -126,16 +126,17 @@ class MqttAlarm(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, async def _subscribe_topics(self): """(Re)Subscribe to topics.""" @callback - def message_received(topic, payload, qos): + def message_received(msg): """Run when new MQTT message has been received.""" - if payload not in (STATE_ALARM_DISARMED, STATE_ALARM_ARMED_HOME, - STATE_ALARM_ARMED_AWAY, - STATE_ALARM_ARMED_NIGHT, - STATE_ALARM_PENDING, - STATE_ALARM_TRIGGERED): - _LOGGER.warning("Received unexpected payload: %s", payload) + if msg.payload not in ( + STATE_ALARM_DISARMED, STATE_ALARM_ARMED_HOME, + STATE_ALARM_ARMED_AWAY, + STATE_ALARM_ARMED_NIGHT, + STATE_ALARM_PENDING, + STATE_ALARM_TRIGGERED): + _LOGGER.warning("Received unexpected payload: %s", msg.payload) return - self._state = payload + self._state = msg.payload self.async_write_ha_state() self._sub_state = await subscription.async_subscribe_topics( diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index 103958376c0..f2a93d06f8e 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -133,8 +133,9 @@ class MqttBinarySensor(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self.async_write_ha_state() @callback - def state_message_received(_topic, payload, _qos): + def state_message_received(msg): """Handle a new received MQTT state message.""" + payload = msg.payload value_template = self._config.get(CONF_VALUE_TEMPLATE) if value_template is not None: payload = value_template.async_render_with_possible_json_value( diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index b9cdb5bef02..ca41f3c4225 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -102,9 +102,9 @@ class MqttCamera(MqttDiscoveryUpdate, Camera): async def _subscribe_topics(self): """(Re)Subscribe to topics.""" @callback - def message_received(topic, payload, qos): + def message_received(msg): """Handle new MQTT messages.""" - self._last_image = payload + self._last_image = msg.payload self._sub_state = await subscription.async_subscribe_topics( self.hass, self._sub_state, diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index 25f5aa68571..ae847437932 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -288,8 +288,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, qos = self._config.get(CONF_QOS) @callback - def handle_current_temp_received(topic, payload, qos): + def handle_current_temp_received(msg): """Handle current temperature coming via MQTT.""" + payload = msg.payload if CONF_CURRENT_TEMPERATURE_TEMPLATE in self._value_templates: payload =\ self._value_templates[CONF_CURRENT_TEMPERATURE_TEMPLATE].\ @@ -308,8 +309,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_mode_received(topic, payload, qos): + def handle_mode_received(msg): """Handle receiving mode via MQTT.""" + payload = msg.payload if CONF_MODE_STATE_TEMPLATE in self._value_templates: payload = self._value_templates[CONF_MODE_STATE_TEMPLATE].\ async_render_with_possible_json_value(payload) @@ -327,8 +329,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_temperature_received(topic, payload, qos): + def handle_temperature_received(msg): """Handle target temperature coming via MQTT.""" + payload = msg.payload if CONF_TEMPERATURE_STATE_TEMPLATE in self._value_templates: payload = \ self._value_templates[CONF_TEMPERATURE_STATE_TEMPLATE].\ @@ -347,8 +350,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_fan_mode_received(topic, payload, qos): + def handle_fan_mode_received(msg): """Handle receiving fan mode via MQTT.""" + payload = msg.payload if CONF_FAN_MODE_STATE_TEMPLATE in self._value_templates: payload = \ self._value_templates[CONF_FAN_MODE_STATE_TEMPLATE].\ @@ -367,8 +371,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_swing_mode_received(topic, payload, qos): + def handle_swing_mode_received(msg): """Handle receiving swing mode via MQTT.""" + payload = msg.payload if CONF_SWING_MODE_STATE_TEMPLATE in self._value_templates: payload = \ self._value_templates[CONF_SWING_MODE_STATE_TEMPLATE].\ @@ -387,8 +392,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_away_mode_received(topic, payload, qos): + def handle_away_mode_received(msg): """Handle receiving away mode via MQTT.""" + payload = msg.payload payload_on = self._config.get(CONF_PAYLOAD_ON) payload_off = self._config.get(CONF_PAYLOAD_OFF) if CONF_AWAY_MODE_STATE_TEMPLATE in self._value_templates: @@ -416,8 +422,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_aux_mode_received(topic, payload, qos): + def handle_aux_mode_received(msg): """Handle receiving aux mode via MQTT.""" + payload = msg.payload payload_on = self._config.get(CONF_PAYLOAD_ON) payload_off = self._config.get(CONF_PAYLOAD_OFF) if CONF_AUX_STATE_TEMPLATE in self._value_templates: @@ -444,8 +451,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': qos} @callback - def handle_hold_mode_received(topic, payload, qos): + def handle_hold_mode_received(msg): """Handle receiving hold mode via MQTT.""" + payload = msg.payload if CONF_HOLD_STATE_TEMPLATE in self._value_templates: payload = self._value_templates[CONF_HOLD_STATE_TEMPLATE].\ async_render_with_possible_json_value(payload) diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index f4f73c76863..37222cbe868 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -216,19 +216,20 @@ class MqttCover(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, topics = {} @callback - def tilt_updated(topic, payload, qos): + def tilt_updated(msg): """Handle tilt updates.""" - if (payload.isnumeric() and - (self._config.get(CONF_TILT_MIN) <= int(payload) <= + if (msg.payload.isnumeric() and + (self._config.get(CONF_TILT_MIN) <= int(msg.payload) <= self._config.get(CONF_TILT_MAX))): - level = self.find_percentage_in_range(float(payload)) + level = self.find_percentage_in_range(float(msg.payload)) self._tilt_value = level self.async_write_ha_state() @callback - def state_message_received(topic, payload, qos): + def state_message_received(msg): """Handle new MQTT state messages.""" + payload = msg.payload if template is not None: payload = template.async_render_with_possible_json_value( payload) @@ -243,8 +244,9 @@ class MqttCover(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self.async_write_ha_state() @callback - def position_message_received(topic, payload, qos): + def position_message_received(msg): """Handle new MQTT state messages.""" + payload = msg.payload if template is not None: payload = template.async_render_with_possible_json_value( payload) diff --git a/homeassistant/components/mqtt/device_tracker.py b/homeassistant/components/mqtt/device_tracker.py index 06bd6d771a4..bf55d955ce1 100644 --- a/homeassistant/components/mqtt/device_tracker.py +++ b/homeassistant/components/mqtt/device_tracker.py @@ -31,10 +31,10 @@ async def async_setup_scanner(hass, config, async_see, discovery_info=None): for dev_id, topic in devices.items(): @callback - def async_message_received(topic, payload, qos, dev_id=dev_id): + def async_message_received(msg, dev_id=dev_id): """Handle received MQTT message.""" hass.async_create_task( - async_see(dev_id=dev_id, location_name=payload)) + async_see(dev_id=dev_id, location_name=msg.payload)) await mqtt.async_subscribe( hass, topic, async_message_received, qos) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 885c14f609f..745e54d0ed7 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -200,8 +200,10 @@ def clear_discovery_hash(hass, discovery_hash): async def async_start(hass: HomeAssistantType, discovery_topic, hass_config, config_entry=None) -> bool: """Initialize of MQTT Discovery.""" - async def async_device_message_received(topic, payload, qos): + async def async_device_message_received(msg): """Process the received message.""" + payload = msg.payload + topic = msg.topic match = TOPIC_MATCHER.match(topic) if not match: diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index eb1e6e84101..7c9f816eff7 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -212,9 +212,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, templates[key] = tpl.async_render_with_possible_json_value @callback - def state_received(topic, payload, qos): + def state_received(msg): """Handle new received MQTT message.""" - payload = templates[CONF_STATE](payload) + payload = templates[CONF_STATE](msg.payload) if payload == self._payload[STATE_ON]: self._state = True elif payload == self._payload[STATE_OFF]: @@ -228,9 +228,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, 'qos': self._config.get(CONF_QOS)} @callback - def speed_received(topic, payload, qos): + def speed_received(msg): """Handle new received MQTT message for the speed.""" - payload = templates[ATTR_SPEED](payload) + payload = templates[ATTR_SPEED](msg.payload) if payload == self._payload[SPEED_LOW]: self._speed = SPEED_LOW elif payload == self._payload[SPEED_MEDIUM]: @@ -247,9 +247,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._speed = SPEED_OFF @callback - def oscillation_received(topic, payload, qos): + def oscillation_received(msg): """Handle new received MQTT message for the oscillation.""" - payload = templates[OSCILLATION](payload) + payload = templates[OSCILLATION](msg.payload) if payload == self._payload[OSCILLATE_ON_PAYLOAD]: self._oscillation = True elif payload == self._payload[OSCILLATE_OFF_PAYLOAD]: diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index 256e0f46d85..a985a707485 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -254,11 +254,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, last_state = await self.async_get_last_state() @callback - def state_received(topic, payload, qos): + def state_received(msg): """Handle new MQTT messages.""" - payload = templates[CONF_STATE](payload) + payload = templates[CONF_STATE](msg.payload) if not payload: - _LOGGER.debug("Ignoring empty state message from '%s'", topic) + _LOGGER.debug("Ignoring empty state message from '%s'", + msg.topic) return if payload == self._payload['on']: @@ -276,12 +277,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._state = last_state.state == STATE_ON @callback - def brightness_received(topic, payload, qos): + def brightness_received(msg): """Handle new MQTT messages for the brightness.""" - payload = templates[CONF_BRIGHTNESS](payload) + payload = templates[CONF_BRIGHTNESS](msg.payload) if not payload: _LOGGER.debug("Ignoring empty brightness message from '%s'", - topic) + msg.topic) return device_value = float(payload) @@ -305,11 +306,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._brightness = None @callback - def rgb_received(topic, payload, qos): + def rgb_received(msg): """Handle new MQTT messages for RGB.""" - payload = templates[CONF_RGB](payload) + payload = templates[CONF_RGB](msg.payload) if not payload: - _LOGGER.debug("Ignoring empty rgb message from '%s'", topic) + _LOGGER.debug("Ignoring empty rgb message from '%s'", + msg.topic) return rgb = [int(val) for val in payload.split(',')] @@ -333,12 +335,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._hs = (0, 0) @callback - def color_temp_received(topic, payload, qos): + def color_temp_received(msg): """Handle new MQTT messages for color temperature.""" - payload = templates[CONF_COLOR_TEMP](payload) + payload = templates[CONF_COLOR_TEMP](msg.payload) if not payload: _LOGGER.debug("Ignoring empty color temp message from '%s'", - topic) + msg.topic) return self._color_temp = int(payload) @@ -359,11 +361,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._color_temp = None @callback - def effect_received(topic, payload, qos): + def effect_received(msg): """Handle new MQTT messages for effect.""" - payload = templates[CONF_EFFECT](payload) + payload = templates[CONF_EFFECT](msg.payload) if not payload: - _LOGGER.debug("Ignoring empty effect message from '%s'", topic) + _LOGGER.debug("Ignoring empty effect message from '%s'", + msg.topic) return self._effect = payload @@ -384,11 +387,11 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._effect = None @callback - def hs_received(topic, payload, qos): + def hs_received(msg): """Handle new MQTT messages for hs color.""" - payload = templates[CONF_HS](payload) + payload = templates[CONF_HS](msg.payload) if not payload: - _LOGGER.debug("Ignoring empty hs message from '%s'", topic) + _LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic) return try: @@ -412,12 +415,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._hs = (0, 0) @callback - def white_value_received(topic, payload, qos): + def white_value_received(msg): """Handle new MQTT messages for white value.""" - payload = templates[CONF_WHITE_VALUE](payload) + payload = templates[CONF_WHITE_VALUE](msg.payload) if not payload: _LOGGER.debug("Ignoring empty white value message from '%s'", - topic) + msg.topic) return device_value = float(payload) @@ -441,12 +444,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, self._white_value = None @callback - def xy_received(topic, payload, qos): + def xy_received(msg): """Handle new MQTT messages for xy color.""" - payload = templates[CONF_XY](payload) + payload = templates[CONF_XY](msg.payload) if not payload: _LOGGER.debug("Ignoring empty xy-color message from '%s'", - topic) + msg.topic) return xy_color = [float(val) for val in payload.split(',')] diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index df3aa7fe89e..12f688afbf7 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -201,9 +201,9 @@ class MqttLightJson(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, last_state = await self.async_get_last_state() @callback - def state_received(topic, payload, qos): + def state_received(msg): """Handle new MQTT messages.""" - values = json.loads(payload) + values = json.loads(msg.payload) if values['state'] == 'ON': self._state = True diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 0773a0cf05d..27c1fb00441 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -188,10 +188,10 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, last_state = await self.async_get_last_state() @callback - def state_received(topic, payload, qos): + def state_received(msg): """Handle new MQTT messages.""" state = self._templates[CONF_STATE_TEMPLATE].\ - async_render_with_possible_json_value(payload) + async_render_with_possible_json_value(msg.payload) if state == STATE_ON: self._state = True elif state == STATE_OFF: @@ -203,7 +203,7 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, try: self._brightness = int( self._templates[CONF_BRIGHTNESS_TEMPLATE]. - async_render_with_possible_json_value(payload) + async_render_with_possible_json_value(msg.payload) ) except ValueError: _LOGGER.warning("Invalid brightness value received") @@ -212,7 +212,7 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, try: self._color_temp = int( self._templates[CONF_COLOR_TEMP_TEMPLATE]. - async_render_with_possible_json_value(payload) + async_render_with_possible_json_value(msg.payload) ) except ValueError: _LOGGER.warning("Invalid color temperature value received") @@ -221,13 +221,13 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, try: red = int( self._templates[CONF_RED_TEMPLATE]. - async_render_with_possible_json_value(payload)) + async_render_with_possible_json_value(msg.payload)) green = int( self._templates[CONF_GREEN_TEMPLATE]. - async_render_with_possible_json_value(payload)) + async_render_with_possible_json_value(msg.payload)) blue = int( self._templates[CONF_BLUE_TEMPLATE]. - async_render_with_possible_json_value(payload)) + async_render_with_possible_json_value(msg.payload)) self._hs = color_util.color_RGB_to_hs(red, green, blue) except ValueError: _LOGGER.warning("Invalid color value received") @@ -236,14 +236,14 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, try: self._white_value = int( self._templates[CONF_WHITE_VALUE_TEMPLATE]. - async_render_with_possible_json_value(payload) + async_render_with_possible_json_value(msg.payload) ) except ValueError: _LOGGER.warning('Invalid white value received') if self._templates[CONF_EFFECT_TEMPLATE] is not None: effect = self._templates[CONF_EFFECT_TEMPLATE].\ - async_render_with_possible_json_value(payload) + async_render_with_possible_json_value(msg.payload) if effect in self._config.get(CONF_EFFECT_LIST): self._effect = effect diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index c8f1bedeeff..d9adc37d79a 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -120,8 +120,9 @@ class MqttLock(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, value_template.hass = self.hass @callback - def message_received(topic, payload, qos): + def message_received(msg): """Handle new MQTT messages.""" + payload = msg.payload if value_template is not None: payload = value_template.async_render_with_possible_json_value( payload) diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 0a507b1bc4f..c6ef3344fcf 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -133,8 +133,9 @@ class MqttSensor(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, template.hass = self.hass @callback - def message_received(topic, payload, qos): + def message_received(msg): """Handle new MQTT messages.""" + payload = msg.payload # auto-expire enabled? expire_after = self._config.get(CONF_EXPIRE_AFTER) if expire_after is not None and expire_after > 0: diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 50243274bfb..de7da6b7249 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -143,8 +143,9 @@ class MqttSwitch(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, template.hass = self.hass @callback - def state_message_received(topic, payload, qos): + def state_message_received(msg): """Handle new MQTT state messages.""" + payload = msg.payload if template is not None: payload = template.async_render_with_possible_json_value( payload) diff --git a/homeassistant/components/mqtt/vacuum.py b/homeassistant/components/mqtt/vacuum.py index 081bf5fc583..eb7e78b6254 100644 --- a/homeassistant/components/mqtt/vacuum.py +++ b/homeassistant/components/mqtt/vacuum.py @@ -284,45 +284,45 @@ class MqttVacuum(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, tpl.hass = self.hass @callback - def message_received(topic, payload, qos): + def message_received(msg): """Handle new MQTT message.""" - if topic == self._state_topics[CONF_BATTERY_LEVEL_TOPIC] and \ + if msg.topic == self._state_topics[CONF_BATTERY_LEVEL_TOPIC] and \ self._templates[CONF_BATTERY_LEVEL_TEMPLATE]: battery_level = self._templates[CONF_BATTERY_LEVEL_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if battery_level is not None: self._battery_level = int(battery_level) - if topic == self._state_topics[CONF_CHARGING_TOPIC] and \ + if msg.topic == self._state_topics[CONF_CHARGING_TOPIC] and \ self._templates[CONF_CHARGING_TEMPLATE]: charging = self._templates[CONF_CHARGING_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if charging is not None: self._charging = cv.boolean(charging) - if topic == self._state_topics[CONF_CLEANING_TOPIC] and \ + if msg.topic == self._state_topics[CONF_CLEANING_TOPIC] and \ self._templates[CONF_CLEANING_TEMPLATE]: cleaning = self._templates[CONF_CLEANING_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if cleaning is not None: self._cleaning = cv.boolean(cleaning) - if topic == self._state_topics[CONF_DOCKED_TOPIC] and \ + if msg.topic == self._state_topics[CONF_DOCKED_TOPIC] and \ self._templates[CONF_DOCKED_TEMPLATE]: docked = self._templates[CONF_DOCKED_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if docked is not None: self._docked = cv.boolean(docked) - if topic == self._state_topics[CONF_ERROR_TOPIC] and \ + if msg.topic == self._state_topics[CONF_ERROR_TOPIC] and \ self._templates[CONF_ERROR_TEMPLATE]: error = self._templates[CONF_ERROR_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if error is not None: self._error = cv.string(error) @@ -338,11 +338,11 @@ class MqttVacuum(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate, else: self._status = "Stopped" - if topic == self._state_topics[CONF_FAN_SPEED_TOPIC] and \ + if msg.topic == self._state_topics[CONF_FAN_SPEED_TOPIC] and \ self._templates[CONF_FAN_SPEED_TEMPLATE]: fan_speed = self._templates[CONF_FAN_SPEED_TEMPLATE]\ .async_render_with_possible_json_value( - payload, error_value=None) + msg.payload, error_value=None) if fan_speed is not None: self._fan_speed = fan_speed diff --git a/homeassistant/components/mysensors/gateway.py b/homeassistant/components/mysensors/gateway.py index d4a52655d19..62ea20cbb91 100644 --- a/homeassistant/components/mysensors/gateway.py +++ b/homeassistant/components/mysensors/gateway.py @@ -98,9 +98,9 @@ async def _get_gateway(hass, config, gateway_conf, persistence_file): def sub_callback(topic, sub_cb, qos): """Call MQTT subscribe function.""" @callback - def internal_callback(*args): + def internal_callback(msg): """Call callback.""" - sub_cb(*args) + sub_cb(msg.topic, msg.payload, msg.qos) hass.async_create_task( mqtt.async_subscribe(topic, internal_callback, qos)) diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 81941173d68..5c441a68bea 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -316,8 +316,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'test-topic' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'test-topic' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload unsub() @@ -343,8 +343,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'test-topic/bier/on' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'test-topic/bier/on' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_level_wildcard_no_subtree_match(self): """Test the subscription of wildcard topics.""" @@ -372,8 +372,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'test-topic/bier/on' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'test-topic/bier/on' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_subtree_wildcard_root_topic(self): """Test the subscription of wildcard topics.""" @@ -383,8 +383,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'test-topic' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'test-topic' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_subtree_wildcard_no_match(self): """Test the subscription of wildcard topics.""" @@ -403,8 +403,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'hi/test-topic' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'hi/test-topic' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self): """Test the subscription of wildcard topics.""" @@ -414,8 +414,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert 'hi/test-topic/here-iam' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert 'hi/test-topic/here-iam' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self): """Test the subscription of wildcard topics.""" @@ -443,8 +443,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert '$test-topic/subtree/on' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert '$test-topic/subtree/on' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_sys_root_and_wildcard_topic(self): """Test the subscription of $ root and wildcard topics.""" @@ -454,8 +454,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert '$test-topic/some-topic' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert '$test-topic/some-topic' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(self): """Test the subscription of $ root and wildcard subtree topics.""" @@ -466,8 +466,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.hass.block_till_done() assert 1 == len(self.calls) - assert '$test-topic/subtree/some-topic' == self.calls[0][0] - assert 'test-payload' == self.calls[0][1] + assert '$test-topic/subtree/some-topic' == self.calls[0][0].topic + assert 'test-payload' == self.calls[0][0].payload def test_subscribe_special_characters(self): """Test the subscription to topics with special characters.""" @@ -479,8 +479,8 @@ class TestMQTTCallbacks(unittest.TestCase): fire_mqtt_message(self.hass, topic, payload) self.hass.block_till_done() assert 1 == len(self.calls) - assert topic == self.calls[0][0] - assert payload == self.calls[0][1] + assert topic == self.calls[0][0].topic + assert payload == self.calls[0][0].payload def test_mqtt_failed_connection_results_in_disconnect(self): """Test if connection failure leads to disconnect.""" diff --git a/tests/components/mqtt/test_subscription.py b/tests/components/mqtt/test_subscription.py index b4b005d0d1e..cd274079e01 100644 --- a/tests/components/mqtt/test_subscription.py +++ b/tests/components/mqtt/test_subscription.py @@ -35,8 +35,8 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog): async_fire_mqtt_message(hass, 'test-topic1', 'test-payload1') await hass.async_block_till_done() assert 1 == len(calls1) - assert 'test-topic1' == calls1[0][0] - assert 'test-payload1' == calls1[0][1] + assert 'test-topic1' == calls1[0][0].topic + assert 'test-payload1' == calls1[0][0].payload assert 0 == len(calls2) async_fire_mqtt_message(hass, 'test-topic2', 'test-payload2') @@ -44,8 +44,8 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog): await hass.async_block_till_done() assert 1 == len(calls1) assert 1 == len(calls2) - assert 'test-topic2' == calls2[0][0] - assert 'test-payload2' == calls2[0][1] + assert 'test-topic2' == calls2[0][0].topic + assert 'test-payload2' == calls2[0][0].payload await async_unsubscribe_topics(hass, sub_state) @@ -108,8 +108,8 @@ async def test_modify_topics(hass, mqtt_mock, caplog): await hass.async_block_till_done() await hass.async_block_till_done() assert 2 == len(calls1) - assert 'test-topic1_1' == calls1[1][0] - assert 'test-payload' == calls1[1][1] + assert 'test-topic1_1' == calls1[1][0].topic + assert 'test-payload' == calls1[1][0].payload assert 1 == len(calls2) await async_unsubscribe_topics(hass, sub_state)