Fix MQTT retained message not being re-dispatched (#12004)

* Fix MQTT retained message not being re-dispatched

* Fix tests

* Use paho-mqtt for retained messages

* Improve code style

* Store list of subscribers

* Fix lint error

* Adhere to Home Assistant's logging standard

"Try to avoid brackets and additional quotes around the output to make it easier for users to parse the log."
 - https://home-assistant.io/developers/development_guidelines/

* Add reconnect tests

* Fix lint error

* Introduce Subscription

Tests still need to be updated

* Use namedtuple for MQTT messages

... And fix issues

Accessing the config manually at runtime isn't ideal

* Fix MQTT __init__.py tests

* Updated usage of Mocks
* Moved tests that were testing subscriptions out of the MQTTComponent test, because of how mock.patch was used
* Adjusted the remaining tests for the MQTT clients new behavior - e.g. self.progress was removed
* Updated the async_fire_mqtt_message helper

*  Update MQTT tests

* Re-introduce the MQTT subscriptions through the dispatcher for tests - quite ugly though...  🚧
* Update fixtures to use our new MQTT mock 🎨

* 📝 Update base code according to comments

* 🔨 Adjust MQTT test base

* 🔨 Update other MQTT tests

* 🍎 Fix carriage return in source files

Apparently test_mqtt_json.py and test_mqtt_template.py were written on Windows. In order to not mess up the diff, I'll just redo the carriage return.

* 🎨 Remove unused import

* 📝 Remove fire_mqtt_client_message

* 🐛 Fix using python 3.6 method

What's very interesting is that 3.4 didn't fail on travis...

* 🐛 Fix using assert directly
This commit is contained in:
Otto Winter 2018-02-11 18:17:58 +01:00 committed by Paulus Schoutsen
parent 17e5740a0c
commit b1c0cabe6c
15 changed files with 1531 additions and 1490 deletions

View file

@ -5,6 +5,10 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/ https://home-assistant.io/components/mqtt/
""" """
import asyncio import asyncio
from collections import namedtuple
from itertools import groupby
from typing import Optional
from operator import attrgetter
import logging import logging
import os import os
import socket import socket
@ -15,13 +19,12 @@ import requests.certs
import voluptuous as vol import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.helpers import template, config_validation as cv from homeassistant.helpers import template, ConfigType, config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, dispatcher_send)
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.util.async import ( from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe) run_coroutine_threadsafe, run_callback_threadsafe)
@ -39,7 +42,6 @@ DOMAIN = 'mqtt'
DATA_MQTT = 'mqtt' DATA_MQTT = 'mqtt'
SERVICE_PUBLISH = 'publish' SERVICE_PUBLISH = 'publish'
SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
CONF_EMBEDDED = 'embedded' CONF_EMBEDDED = 'embedded'
CONF_BROKER = 'broker' CONF_BROKER = 'broker'
@ -173,7 +175,6 @@ MQTT_RW_PLATFORM_SCHEMA = MQTT_BASE_PLATFORM_SCHEMA.extend({
vol.Optional(CONF_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
}) })
# Service call validation schema # Service call validation schema
MQTT_PUBLISH_SCHEMA = vol.Schema({ MQTT_PUBLISH_SCHEMA = vol.Schema({
vol.Required(ATTR_TOPIC): valid_publish_topic, vol.Required(ATTR_TOPIC): valid_publish_topic,
@ -221,32 +222,13 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
@bind_hass @bind_hass
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
encoding='utf-8'): encoding='utf-8'):
"""Subscribe to an MQTT topic.""" """Subscribe to an MQTT topic.
@callback
def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos):
"""Match subscribed MQTT topic."""
if not _match_topic(topic, dp_topic):
return
if encoding is not None: Call the return value to unsubscribe.
try: """
payload = dp_payload.decode(encoding) async_remove = \
_LOGGER.debug("Received message on %s: %s", dp_topic, payload) yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback,
except (AttributeError, UnicodeDecodeError): qos, encoding)
_LOGGER.error("Illegal payload encoding %s from "
"MQTT topic: %s, Payload: %s",
encoding, dp_topic, dp_payload)
return
else:
_LOGGER.debug("Received binary message on %s", dp_topic)
payload = dp_payload
hass.async_run_job(msg_callback, dp_topic, payload, dp_qos)
async_remove = async_dispatcher_connect(
hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
return async_remove return async_remove
@ -308,7 +290,7 @@ def _async_setup_discovery(hass, config):
@asyncio.coroutine @asyncio.coroutine
def async_setup(hass, config): def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Start the MQTT protocol service.""" """Start the MQTT protocol service."""
conf = config.get(DOMAIN) conf = config.get(DOMAIN)
@ -351,8 +333,8 @@ def async_setup(hass, config):
return False return False
# For cloudmqtt.com, secured connection, auto fill in certificate # For cloudmqtt.com, secured connection, auto fill in certificate
if certificate is None and 19999 < port < 30000 and \ if (certificate is None and 19999 < port < 30000 and
broker.endswith('.cloudmqtt.com'): broker.endswith('.cloudmqtt.com')):
certificate = os.path.join(os.path.dirname(__file__), certificate = os.path.join(os.path.dirname(__file__),
'addtrustexternalcaroot.crt') 'addtrustexternalcaroot.crt')
@ -360,8 +342,12 @@ def async_setup(hass, config):
if certificate == 'auto': if certificate == 'auto':
certificate = requests.certs.where() certificate = requests.certs.where()
will_message = conf.get(CONF_WILL_MESSAGE) will_message = None
birth_message = conf.get(CONF_BIRTH_MESSAGE) if conf.get(CONF_WILL_MESSAGE) is not None:
will_message = Message(**conf.get(CONF_WILL_MESSAGE))
birth_message = None
if conf.get(CONF_BIRTH_MESSAGE) is not None:
birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE))
# Be able to override versions other than TLSv1.0 under Python3.6 # Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version = conf.get(CONF_TLS_VERSION) conf_tls_version = conf.get(CONF_TLS_VERSION)
@ -414,8 +400,8 @@ def async_setup(hass, config):
template.Template(payload_template, hass).async_render() template.Template(payload_template, hass).async_render()
except template.jinja2.TemplateError as exc: except template.jinja2.TemplateError as exc:
_LOGGER.error( _LOGGER.error(
"Unable to publish to '%s': rendering payload template of " "Unable to publish to %s: rendering payload template of "
"'%s' failed because %s", "%s failed because %s",
msg_topic, payload_template, exc) msg_topic, payload_template, exc)
return return
@ -432,13 +418,21 @@ def async_setup(hass, config):
return True return True
Subscription = namedtuple('Subscription',
['topic', 'callback', 'qos', 'encoding'])
Subscription.__new__.__defaults__ = (0, 'utf-8')
Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain'])
Message.__new__.__defaults__ = (0, False)
class MQTT(object): class MQTT(object):
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
def __init__(self, hass, broker, port, client_id, keepalive, username, def __init__(self, hass, broker, port, client_id, keepalive, username,
password, certificate, client_key, client_cert, password, certificate, client_key, client_cert,
tls_insecure, protocol, will_message, birth_message, tls_insecure, protocol, will_message: Optional[Message],
tls_version): birth_message: Optional[Message], tls_version):
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
@ -446,9 +440,7 @@ class MQTT(object):
self.broker = broker self.broker = broker
self.port = port self.port = port
self.keepalive = keepalive self.keepalive = keepalive
self.wanted_topics = {} self.subscriptions = []
self.subscribed_topics = {}
self.progress = {}
self.birth_message = birth_message self.birth_message = birth_message
self._mqttc = None self._mqttc = None
self._paho_lock = asyncio.Lock(loop=hass.loop) self._paho_lock = asyncio.Lock(loop=hass.loop)
@ -474,17 +466,12 @@ class MQTT(object):
if tls_insecure is not None: if tls_insecure is not None:
self._mqttc.tls_insecure_set(tls_insecure) self._mqttc.tls_insecure_set(tls_insecure)
self._mqttc.on_subscribe = self._mqtt_on_subscribe
self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe
self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message self._mqttc.on_message = self._mqtt_on_message
if will_message: if will_message:
self._mqttc.will_set(will_message.get(ATTR_TOPIC), self._mqttc.will_set(*will_message)
will_message.get(ATTR_PAYLOAD),
will_message.get(ATTR_QOS),
will_message.get(ATTR_RETAIN))
@asyncio.coroutine @asyncio.coroutine
def async_publish(self, topic, payload, qos, retain): def async_publish(self, topic, payload, qos, retain):
@ -526,36 +513,53 @@ class MQTT(object):
return self.hass.async_add_job(stop) return self.hass.async_add_job(stop)
@asyncio.coroutine @asyncio.coroutine
def async_subscribe(self, topic, qos): def async_subscribe(self, topic, msg_callback, qos, encoding):
"""Subscribe to a topic. """Set up a subscription to a topic with the provided qos.
This method is a coroutine. This method is a coroutine.
""" """
if not isinstance(topic, str): if not isinstance(topic, str):
raise HomeAssistantError("topic need to be a string!") raise HomeAssistantError("topic needs to be a string!")
with (yield from self._paho_lock): subscription = Subscription(topic, msg_callback, qos, encoding)
if topic in self.subscribed_topics: self.subscriptions.append(subscription)
yield from self._async_perform_subscription(topic, qos)
@callback
def async_remove():
"""Remove subscription."""
if subscription not in self.subscriptions:
raise HomeAssistantError("Can't remove subscription twice")
self.subscriptions.remove(subscription)
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return return
self.wanted_topics[topic] = qos self.hass.async_add_job(self._async_unsubscribe(topic))
result, mid = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)
_raise_on_error(result) return async_remove
self.progress[mid] = topic
@asyncio.coroutine @asyncio.coroutine
def async_unsubscribe(self, topic): def _async_unsubscribe(self, topic):
"""Unsubscribe from topic. """Unsubscribe from a topic.
This method is a coroutine. This method is a coroutine.
""" """
self.wanted_topics.pop(topic, None) with (yield from self._paho_lock):
result, mid = yield from self.hass.async_add_job( result, _ = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic) self._mqttc.unsubscribe, topic)
_raise_on_error(result)
_raise_on_error(result) @asyncio.coroutine
self.progress[mid] = topic def _async_perform_subscription(self, topic, qos):
"""Perform a paho-mqtt subscription."""
_LOGGER.debug("Subscribing to %s", topic)
with (yield from self._paho_lock):
result, _ = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)
_raise_on_error(result)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code): def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
"""On connect callback. """On connect callback.
@ -571,50 +575,50 @@ class MQTT(object):
self._mqttc.disconnect() self._mqttc.disconnect()
return return
self.progress = {} # Group subscriptions to only re-subscribe once for each topic.
self.subscribed_topics = {} keyfunc = attrgetter('topic')
for topic, qos in self.wanted_topics.items(): for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
self.hass.add_job(self.async_subscribe, topic, qos) keyfunc):
# Re-subscribe with the highest requested qos
max_qos = max(subscription.qos for subscription in subs)
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
if self.birth_message: if self.birth_message:
self.hass.add_job(self.async_publish( self.hass.add_job(self.async_publish(*self.birth_message))
self.birth_message.get(ATTR_TOPIC),
self.birth_message.get(ATTR_PAYLOAD),
self.birth_message.get(ATTR_QOS),
self.birth_message.get(ATTR_RETAIN)))
def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Subscribe successful callback."""
topic = self.progress.pop(mid, None)
if topic is None:
return
self.subscribed_topics[topic] = granted_qos[0]
def _mqtt_on_message(self, _mqttc, _userdata, msg): def _mqtt_on_message(self, _mqttc, _userdata, msg):
"""Message received callback.""" """Message received callback."""
dispatcher_send( self.hass.async_add_job(self._mqtt_handle_message, msg)
self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, msg.payload,
msg.qos
)
def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos): @callback
"""Unsubscribe successful callback.""" def _mqtt_handle_message(self, msg):
topic = self.progress.pop(mid, None) _LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)
if topic is None:
return for subscription in self.subscriptions:
self.subscribed_topics.pop(topic, None) if not _match_topic(subscription.topic, msg.topic):
continue
payload = msg.payload
if subscription.encoding is not None:
try:
payload = msg.payload.decode(subscription.encoding)
except (AttributeError, UnicodeDecodeError):
_LOGGER.warning("Can't decode payload %s on %s "
"with encoding %s",
msg.payload, msg.topic,
subscription.encoding)
return
self.hass.async_run_job(subscription.callback,
msg.topic, payload, msg.qos)
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
"""Disconnected callback.""" """Disconnected callback."""
self.progress = {}
self.subscribed_topics = {}
# When disconnected because of calling disconnect() # When disconnected because of calling disconnect()
if result_code == 0: if result_code == 0:
return return
tries = 0 tries = 0
wait_time = 0
while True: while True:
try: try:
@ -693,7 +697,7 @@ class MqttAvailability(Entity):
if self._availability_topic is not None: if self._availability_topic is not None:
yield from async_subscribe( yield from async_subscribe(
self.hass, self._availability_topic, self.hass, self._availability_topic,
availability_message_received, self. _availability_qos) availability_message_received, self._availability_qos)
@property @property
def available(self) -> bool: def available(self) -> bool:

View file

@ -15,7 +15,7 @@ from homeassistant import core as ha, loader
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.helpers import ( from homeassistant.helpers import (
intent, dispatcher, entity, restore_state, entity_registry, intent, entity, restore_state, entity_registry,
entity_platform) entity_platform)
from homeassistant.util.unit_system import METRIC_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.dt as date_util import homeassistant.util.dt as date_util
@ -214,13 +214,12 @@ def async_mock_intent(hass, intent_typ):
@ha.callback @ha.callback
def async_fire_mqtt_message(hass, topic, payload, qos=0): def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
"""Fire the MQTT message.""" """Fire the MQTT message."""
if isinstance(payload, str): if isinstance(payload, str):
payload = payload.encode('utf-8') payload = payload.encode('utf-8')
dispatcher.async_dispatcher_send( msg = mqtt.Message(topic, payload, qos, retain)
hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, hass.async_run_job(hass.data['mqtt']._mqtt_on_message, None, None, msg)
payload, qos)
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
@ -293,16 +292,25 @@ def mock_http_component_app(hass, api_password=None):
@asyncio.coroutine @asyncio.coroutine
def async_mock_mqtt_component(hass): def async_mock_mqtt_component(hass, config=None):
"""Mock the MQTT component.""" """Mock the MQTT component."""
with patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: if config is None:
mock_mqtt().async_connect.return_value = mock_coro(True) config = {mqtt.CONF_BROKER: 'mock-broker'}
yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: { with patch('paho.mqtt.client.Client') as mock_client:
mqtt.CONF_BROKER: 'mock-broker', mock_client().connect.return_value = 0
} mock_client().subscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
result = yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: config
}) })
return mock_mqtt assert result
hass.data['mqtt'] = MagicMock(spec_set=hass.data['mqtt'],
wraps=hass.data['mqtt'])
return hass.data['mqtt']
mock_mqtt_component = threadsafe_coroutine_factory(async_mock_mqtt_component) mock_mqtt_component = threadsafe_coroutine_factory(async_mock_mqtt_component)

View file

@ -1395,53 +1395,60 @@ class TestAlarmControlPanelManualMqtt(unittest.TestCase):
# Component should send disarmed alarm state on startup # Component should send disarmed alarm state on startup
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_DISARMED, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_DISARMED, 0, True)
self.mock_publish.async_publish.reset_mock()
# Arm in home mode # Arm in home mode
alarm_control_panel.alarm_arm_home(self.hass) alarm_control_panel.alarm_arm_home(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_PENDING, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_PENDING, 0, True)
self.mock_publish.async_publish.reset_mock()
# Fast-forward a little bit # Fast-forward a little bit
future = dt_util.utcnow() + timedelta(seconds=1) future = dt_util.utcnow() + timedelta(seconds=1)
with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.' with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.'
'dt_util.utcnow'), return_value=future): 'dt_util.utcnow'), return_value=future):
fire_time_changed(self.hass, future) fire_time_changed(self.hass, future)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_ARMED_HOME, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_ARMED_HOME, 0, True)
self.mock_publish.async_publish.reset_mock()
# Arm in away mode # Arm in away mode
alarm_control_panel.alarm_arm_away(self.hass) alarm_control_panel.alarm_arm_away(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_PENDING, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_PENDING, 0, True)
self.mock_publish.async_publish.reset_mock()
# Fast-forward a little bit # Fast-forward a little bit
future = dt_util.utcnow() + timedelta(seconds=1) future = dt_util.utcnow() + timedelta(seconds=1)
with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.' with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.'
'dt_util.utcnow'), return_value=future): 'dt_util.utcnow'), return_value=future):
fire_time_changed(self.hass, future) fire_time_changed(self.hass, future)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_ARMED_AWAY, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_ARMED_AWAY, 0, True)
self.mock_publish.async_publish.reset_mock()
# Arm in night mode # Arm in night mode
alarm_control_panel.alarm_arm_night(self.hass) alarm_control_panel.alarm_arm_night(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_PENDING, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_PENDING, 0, True)
self.mock_publish.async_publish.reset_mock()
# Fast-forward a little bit # Fast-forward a little bit
future = dt_util.utcnow() + timedelta(seconds=1) future = dt_util.utcnow() + timedelta(seconds=1)
with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.' with patch(('homeassistant.components.alarm_control_panel.manual_mqtt.'
'dt_util.utcnow'), return_value=future): 'dt_util.utcnow'), return_value=future):
fire_time_changed(self.hass, future) fire_time_changed(self.hass, future)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_ARMED_NIGHT, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_ARMED_NIGHT, 0, True)
self.mock_publish.async_publish.reset_mock()
# Disarm # Disarm
alarm_control_panel.alarm_disarm(self.hass) alarm_control_panel.alarm_disarm(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/state', STATE_ALARM_DISARMED, 0, True), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/state', STATE_ALARM_DISARMED, 0, True)

View file

@ -106,8 +106,8 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_home(self.hass) alarm_control_panel.alarm_arm_home(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_HOME', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/command', 'ARM_HOME', 0, False)
def test_arm_home_not_publishes_mqtt_with_invalid_code(self): def test_arm_home_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""
@ -139,8 +139,8 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_away(self.hass) alarm_control_panel.alarm_arm_away(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/command', 'ARM_AWAY', 0, False)
def test_arm_away_not_publishes_mqtt_with_invalid_code(self): def test_arm_away_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""
@ -172,8 +172,8 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_disarm(self.hass) alarm_control_panel.alarm_disarm(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'DISARM', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'alarm/command', 'DISARM', 0, False)
def test_disarm_not_publishes_mqtt_with_invalid_code(self): def test_disarm_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""

View file

@ -104,8 +104,8 @@ class TestMQTTClimate(unittest.TestCase):
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual("cool", state.attributes.get('operation_mode')) self.assertEqual("cool", state.attributes.get('operation_mode'))
self.assertEqual("cool", state.state) self.assertEqual("cool", state.state)
self.assertEqual(('mode-topic', 'cool', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'mode-topic', 'cool', 0, False)
def test_set_operation_pessimistic(self): def test_set_operation_pessimistic(self):
"""Test setting operation mode in pessimistic mode.""" """Test setting operation mode in pessimistic mode."""
@ -178,8 +178,8 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual("low", state.attributes.get('fan_mode')) self.assertEqual("low", state.attributes.get('fan_mode'))
climate.set_fan_mode(self.hass, 'high', ENTITY_CLIMATE) climate.set_fan_mode(self.hass, 'high', ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('fan-mode-topic', 'high', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'fan-mode-topic', 'high', 0, False)
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('high', state.attributes.get('fan_mode')) self.assertEqual('high', state.attributes.get('fan_mode'))
@ -226,8 +226,8 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual("off", state.attributes.get('swing_mode')) self.assertEqual("off", state.attributes.get('swing_mode'))
climate.set_swing_mode(self.hass, 'on', ENTITY_CLIMATE) climate.set_swing_mode(self.hass, 'on', ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('swing-mode-topic', 'on', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'swing-mode-topic', 'on', 0, False)
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual("on", state.attributes.get('swing_mode')) self.assertEqual("on", state.attributes.get('swing_mode'))
@ -239,15 +239,16 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual(21, state.attributes.get('temperature')) self.assertEqual(21, state.attributes.get('temperature'))
climate.set_operation_mode(self.hass, 'heat', ENTITY_CLIMATE) climate.set_operation_mode(self.hass, 'heat', ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('mode-topic', 'heat', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'mode-topic', 'heat', 0, False)
self.mock_publish.async_publish.reset_mock()
climate.set_temperature(self.hass, temperature=47, climate.set_temperature(self.hass, temperature=47,
entity_id=ENTITY_CLIMATE) entity_id=ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual(47, state.attributes.get('temperature')) self.assertEqual(47, state.attributes.get('temperature'))
self.assertEqual(('temperature-topic', 47, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'temperature-topic', 47, 0, False)
def test_set_target_temperature_pessimistic(self): def test_set_target_temperature_pessimistic(self):
"""Test setting the target temperature.""" """Test setting the target temperature."""
@ -328,15 +329,16 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual('off', state.attributes.get('away_mode')) self.assertEqual('off', state.attributes.get('away_mode'))
climate.set_away_mode(self.hass, True, ENTITY_CLIMATE) climate.set_away_mode(self.hass, True, ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('away-mode-topic', 'AN', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'away-mode-topic', 'AN', 0, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('on', state.attributes.get('away_mode')) self.assertEqual('on', state.attributes.get('away_mode'))
climate.set_away_mode(self.hass, False, ENTITY_CLIMATE) climate.set_away_mode(self.hass, False, ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('away-mode-topic', 'AUS', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'away-mode-topic', 'AUS', 0, False)
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('off', state.attributes.get('away_mode')) self.assertEqual('off', state.attributes.get('away_mode'))
@ -372,15 +374,16 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual(None, state.attributes.get('hold_mode')) self.assertEqual(None, state.attributes.get('hold_mode'))
climate.set_hold_mode(self.hass, 'on', ENTITY_CLIMATE) climate.set_hold_mode(self.hass, 'on', ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('hold-topic', 'on', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'hold-topic', 'on', 0, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('on', state.attributes.get('hold_mode')) self.assertEqual('on', state.attributes.get('hold_mode'))
climate.set_hold_mode(self.hass, 'off', ENTITY_CLIMATE) climate.set_hold_mode(self.hass, 'off', ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('hold-topic', 'off', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'hold-topic', 'off', 0, False)
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('off', state.attributes.get('hold_mode')) self.assertEqual('off', state.attributes.get('hold_mode'))
@ -421,15 +424,16 @@ class TestMQTTClimate(unittest.TestCase):
self.assertEqual('off', state.attributes.get('aux_heat')) self.assertEqual('off', state.attributes.get('aux_heat'))
climate.set_aux_heat(self.hass, True, ENTITY_CLIMATE) climate.set_aux_heat(self.hass, True, ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('aux-topic', 'ON', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'aux-topic', 'ON', 0, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('on', state.attributes.get('aux_heat')) self.assertEqual('on', state.attributes.get('aux_heat'))
climate.set_aux_heat(self.hass, False, ENTITY_CLIMATE) climate.set_aux_heat(self.hass, False, ENTITY_CLIMATE)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('aux-topic', 'OFF', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'aux-topic', 'OFF', 0, False)
state = self.hass.states.get(ENTITY_CLIMATE) state = self.hass.states.get(ENTITY_CLIMATE)
self.assertEqual('off', state.attributes.get('aux_heat')) self.assertEqual('off', state.attributes.get('aux_heat'))

View file

@ -116,16 +116,17 @@ class TestCoverMQTT(unittest.TestCase):
cover.open_cover(self.hass, 'cover.test') cover.open_cover(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'OPEN', 0, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_OPEN, state.state) self.assertEqual(STATE_OPEN, state.state)
cover.close_cover(self.hass, 'cover.test') cover.close_cover(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'CLOSE', 0, False)
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_CLOSED, state.state) self.assertEqual(STATE_CLOSED, state.state)
@ -147,8 +148,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.open_cover(self.hass, 'cover.test') cover.open_cover(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'OPEN', 2, False)
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)
@ -170,8 +171,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.close_cover(self.hass, 'cover.test') cover.close_cover(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'CLOSE', 2, False)
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)
@ -193,8 +194,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.stop_cover(self.hass, 'cover.test') cover.stop_cover(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'STOP', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'STOP', 2, False)
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)
@ -295,8 +296,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.set_cover_position(self.hass, 100, 'cover.test') cover.set_cover_position(self.hass, 100, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('position-topic', '38', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'position-topic', '38', 0, False)
def test_set_position_untemplated(self): def test_set_position_untemplated(self):
"""Test setting cover position via template.""" """Test setting cover position via template."""
@ -316,8 +317,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.set_cover_position(self.hass, 62, 'cover.test') cover.set_cover_position(self.hass, 62, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('position-topic', 62, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'position-topic', 62, 0, False)
def test_no_command_topic(self): def test_no_command_topic(self):
"""Test with no command topic.""" """Test with no command topic."""
@ -401,14 +402,15 @@ class TestCoverMQTT(unittest.TestCase):
cover.open_cover_tilt(self.hass, 'cover.test') cover.open_cover_tilt(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 100, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 100, 0, False)
self.mock_publish.async_publish.reset_mock()
cover.close_cover_tilt(self.hass, 'cover.test') cover.close_cover_tilt(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 0, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 0, 0, False)
def test_tilt_given_value(self): def test_tilt_given_value(self):
"""Test tilting to a given value.""" """Test tilting to a given value."""
@ -432,14 +434,15 @@ class TestCoverMQTT(unittest.TestCase):
cover.open_cover_tilt(self.hass, 'cover.test') cover.open_cover_tilt(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 400, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 400, 0, False)
self.mock_publish.async_publish.reset_mock()
cover.close_cover_tilt(self.hass, 'cover.test') cover.close_cover_tilt(self.hass, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 125, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 125, 0, False)
def test_tilt_via_topic(self): def test_tilt_via_topic(self):
"""Test tilt by updating status via MQTT.""" """Test tilt by updating status via MQTT."""
@ -538,8 +541,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.set_cover_tilt_position(self.hass, 50, 'cover.test') cover.set_cover_tilt_position(self.hass, 50, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 50, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 50, 0, False)
def test_tilt_position_altered_range(self): def test_tilt_position_altered_range(self):
"""Test tilt via method invocation with altered range.""" """Test tilt via method invocation with altered range."""
@ -565,8 +568,8 @@ class TestCoverMQTT(unittest.TestCase):
cover.set_cover_tilt_position(self.hass, 50, 'cover.test') cover.set_cover_tilt_position(self.hass, 50, 'cover.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('tilt-command-topic', 25, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'tilt-command-topic', 25, 0, False)
def test_find_percentage_in_range_defaults(self): def test_find_percentage_in_range_defaults(self):
"""Test find percentage in range with default range.""" """Test find percentage in range with default range."""

View file

@ -492,16 +492,18 @@ class TestLightMQTT(unittest.TestCase):
light.turn_on(self.hass, 'light.test') light.turn_on(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'on', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light_rgb/set', 'on', 2, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
light.turn_off(self.hass, 'light.test') light.turn_off(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'off', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light_rgb/set', 'off', 2, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)
@ -512,7 +514,7 @@ class TestLightMQTT(unittest.TestCase):
white_value=80) white_value=80)
self.hass.block_till_done() self.hass.block_till_done()
self.mock_publish().async_publish.assert_has_calls([ self.mock_publish.async_publish.assert_has_calls([
mock.call('test_light_rgb/set', 'on', 2, False), mock.call('test_light_rgb/set', 'on', 2, False),
mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False), mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False),
mock.call('test_light_rgb/brightness/set', 50, 2, False), mock.call('test_light_rgb/brightness/set', 50, 2, False),
@ -550,7 +552,7 @@ class TestLightMQTT(unittest.TestCase):
light.turn_on(self.hass, 'light.test', rgb_color=[255, 128, 64]) light.turn_on(self.hass, 'light.test', rgb_color=[255, 128, 64])
self.hass.block_till_done() self.hass.block_till_done()
self.mock_publish().async_publish.assert_has_calls([ self.mock_publish.async_publish.assert_has_calls([
mock.call('test_light_rgb/set', 'on', 0, False), mock.call('test_light_rgb/set', 'on', 0, False),
mock.call('test_light_rgb/rgb/set', '#ff8040', 0, False), mock.call('test_light_rgb/rgb/set', '#ff8040', 0, False),
], any_order=True) ], any_order=True)
@ -701,16 +703,17 @@ class TestLightMQTT(unittest.TestCase):
# Should get the following MQTT messages. # Should get the following MQTT messages.
# test_light/set: 'ON' # test_light/set: 'ON'
# test_light/bright: 50 # test_light/bright: 50
self.assertEqual(('test_light/set', 'ON', 0, False), self.mock_publish.async_publish.assert_has_calls([
self.mock_publish.mock_calls[-4][1]) mock.call('test_light/set', 'ON', 0, False),
self.assertEqual(('test_light/bright', 50, 0, False), mock.call('test_light/bright', 50, 0, False),
self.mock_publish.mock_calls[-2][1]) ], any_order=True)
self.mock_publish.async_publish.reset_mock()
light.turn_off(self.hass, 'light.test') light.turn_off(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light/set', 'OFF', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light/set', 'OFF', 0, False)
def test_on_command_last(self): def test_on_command_last(self):
"""Test on command being sent after brightness.""" """Test on command being sent after brightness."""
@ -733,16 +736,17 @@ class TestLightMQTT(unittest.TestCase):
# Should get the following MQTT messages. # Should get the following MQTT messages.
# test_light/bright: 50 # test_light/bright: 50
# test_light/set: 'ON' # test_light/set: 'ON'
self.assertEqual(('test_light/bright', 50, 0, False), self.mock_publish.async_publish.assert_has_calls([
self.mock_publish.mock_calls[-4][1]) mock.call('test_light/bright', 50, 0, False),
self.assertEqual(('test_light/set', 'ON', 0, False), mock.call('test_light/set', 'ON', 0, False),
self.mock_publish.mock_calls[-2][1]) ], any_order=True)
self.mock_publish.async_publish.reset_mock()
light.turn_off(self.hass, 'light.test') light.turn_off(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light/set', 'OFF', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light/set', 'OFF', 0, False)
def test_on_command_brightness(self): def test_on_command_brightness(self):
"""Test on command being sent as only brightness.""" """Test on command being sent as only brightness."""
@ -767,21 +771,24 @@ class TestLightMQTT(unittest.TestCase):
# Should get the following MQTT messages. # Should get the following MQTT messages.
# test_light/bright: 255 # test_light/bright: 255
self.assertEqual(('test_light/bright', 255, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light/bright', 255, 0, False)
self.mock_publish.async_publish.reset_mock()
light.turn_off(self.hass, 'light.test') light.turn_off(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light/set', 'OFF', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light/set', 'OFF', 0, False)
self.mock_publish.async_publish.reset_mock()
# Turn on w/ brightness # Turn on w/ brightness
light.turn_on(self.hass, 'light.test', brightness=50) light.turn_on(self.hass, 'light.test', brightness=50)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light/bright', 50, 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'test_light/bright', 50, 0, False)
self.mock_publish.async_publish.reset_mock()
light.turn_off(self.hass, 'light.test') light.turn_off(self.hass, 'light.test')
self.hass.block_till_done() self.hass.block_till_done()
@ -791,10 +798,10 @@ class TestLightMQTT(unittest.TestCase):
light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75]) light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75])
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light/rgb', '75,75,75', 0, False), self.mock_publish.async_publish.assert_has_calls([
self.mock_publish.mock_calls[-4][1]) mock.call('test_light/rgb', '75,75,75', 0, False),
self.assertEqual(('test_light/bright', 50, 0, False), mock.call('test_light/bright', 50, 0, False)
self.mock_publish.mock_calls[-2][1]) ], any_order=True)
def test_default_availability_payload(self): def test_default_availability_payload(self):
"""Test availability by default payload with defined topic.""" """Test availability by default payload with defined topic."""

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -70,16 +70,17 @@ class TestLockMQTT(unittest.TestCase):
lock.lock(self.hass, 'lock.test') lock.lock(self.hass, 'lock.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'LOCK', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'LOCK', 2, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get('lock.test') state = self.hass.states.get('lock.test')
self.assertEqual(STATE_LOCKED, state.state) self.assertEqual(STATE_LOCKED, state.state)
lock.unlock(self.hass, 'lock.test') lock.unlock(self.hass, 'lock.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'UNLOCK', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'UNLOCK', 2, False)
state = self.hass.states.get('lock.test') state = self.hass.states.get('lock.test')
self.assertEqual(STATE_UNLOCKED, state.state) self.assertEqual(STATE_UNLOCKED, state.state)

View file

@ -18,7 +18,7 @@ def test_subscribing_config_topic(hass, mqtt_mock):
assert mqtt_mock.async_subscribe.called assert mqtt_mock.async_subscribe.called
call_args = mqtt_mock.async_subscribe.mock_calls[0][1] call_args = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args[0] == discovery_topic + '/#' assert call_args[0] == discovery_topic + '/#'
assert call_args[1] == 0 assert call_args[2] == 0
@asyncio.coroutine @asyncio.coroutine

View file

@ -1,6 +1,5 @@
"""The tests for the MQTT component.""" """The tests for the MQTT component."""
import asyncio import asyncio
from collections import namedtuple, OrderedDict
import unittest import unittest
from unittest import mock from unittest import mock
import socket import socket
@ -9,26 +8,27 @@ import ssl
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from homeassistant.const import ( from homeassistant.const import (EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE,
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_STOP) EVENT_HOMEASSISTANT_STOP)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from tests.common import ( from tests.common import (get_test_home_assistant, mock_coro,
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro) mock_mqtt_component,
threadsafe_coroutine_factory, fire_mqtt_message,
async_fire_mqtt_message)
@asyncio.coroutine @asyncio.coroutine
def mock_mqtt_client(hass, config=None): def async_mock_mqtt_client(hass, config=None):
"""Mock the MQTT paho client.""" """Mock the MQTT paho client."""
if config is None: if config is None:
config = { config = {mqtt.CONF_BROKER: 'mock-broker'}
mqtt.CONF_BROKER: 'mock-broker'
}
with mock.patch('paho.mqtt.client.Client') as mock_client: with mock.patch('paho.mqtt.client.Client') as mock_client:
mock_client().connect = lambda *args: 0 mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
result = yield from async_setup_component(hass, mqtt.DOMAIN, { result = yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: config mqtt.DOMAIN: config
}) })
@ -36,8 +36,11 @@ def mock_mqtt_client(hass, config=None):
return mock_client() return mock_client()
mock_mqtt_client = threadsafe_coroutine_factory(async_mock_mqtt_client)
# pylint: disable=invalid-name # pylint: disable=invalid-name
class TestMQTT(unittest.TestCase): class TestMQTTComponent(unittest.TestCase):
"""Test the MQTT component.""" """Test the MQTT component."""
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
@ -55,12 +58,8 @@ class TestMQTT(unittest.TestCase):
"""Helper for recording calls.""" """Helper for recording calls."""
self.calls.append(args) self.calls.append(args)
def test_client_starts_on_home_assistant_mqtt_setup(self):
"""Test if client is connect after mqtt init on bootstrap."""
assert self.hass.data['mqtt'].async_connect.called
def test_client_stops_on_home_assistant_start(self): def test_client_stops_on_home_assistant_start(self):
"""Test if client stops on HA launch.""" """Test if client stops on HA stop."""
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP) self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
self.hass.block_till_done() self.hass.block_till_done()
self.assertTrue(self.hass.data['mqtt'].async_disconnect.called) self.assertTrue(self.hass.data['mqtt'].async_disconnect.called)
@ -131,6 +130,48 @@ class TestMQTT(unittest.TestCase):
self.hass.data['mqtt'].async_publish.call_args[0][2], 2) self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3]) self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
def test_invalid_mqtt_topics(self):
"""Test invalid topics."""
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one')
# pylint: disable=invalid-name
class TestMQTTCallbacks(unittest.TestCase):
"""Test the MQTT callbacks."""
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_mqtt_client(self.hass)
self.calls = []
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
self.hass.stop()
@callback
def record_calls(self, *args):
"""Helper for recording calls."""
self.calls.append(args)
def test_client_starts_on_home_assistant_mqtt_setup(self):
"""Test if client is connected after mqtt init on bootstrap."""
self.assertEqual(self.hass.data['mqtt']._mqttc.connect.call_count, 1)
def test_receiving_non_utf8_message_gets_logged(self):
"""Test receiving a non utf8 encoded message."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
with self.assertLogs(level='WARNING') as test_handle:
fire_mqtt_message(self.hass, 'test-topic', b'\x9a')
self.hass.block_till_done()
self.assertIn(
"WARNING:homeassistant.components.mqtt:Can't decode payload "
"b'\\x9a' on test-topic with encoding utf-8",
test_handle.output[0])
def test_subscribe_topic(self): def test_subscribe_topic(self):
"""Test the subscription of a topic.""" """Test the subscription of a topic."""
unsub = mqtt.subscribe(self.hass, 'test-topic', self.record_calls) unsub = mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
@ -296,82 +337,6 @@ class TestMQTT(unittest.TestCase):
self.assertEqual(topic, self.calls[0][0]) self.assertEqual(topic, self.calls[0][0])
self.assertEqual(payload, self.calls[0][1]) self.assertEqual(payload, self.calls[0][1])
def test_subscribe_binary_topic(self):
"""Test the subscription to a binary topic."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,
0, None)
fire_mqtt_message(self.hass, 'test-topic', 0x9a)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic', self.calls[0][0])
self.assertEqual(0x9a, self.calls[0][1])
def test_receiving_non_utf8_message_gets_logged(self):
"""Test receiving a non utf8 encoded message."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
with self.assertLogs(level='ERROR') as test_handle:
fire_mqtt_message(self.hass, 'test-topic', 0x9a)
self.hass.block_till_done()
self.assertIn(
"ERROR:homeassistant.components.mqtt:Illegal payload "
"encoding utf-8 from MQTT "
"topic: test-topic, Payload: 154",
test_handle.output[0])
class TestMQTTCallbacks(unittest.TestCase):
"""Test the MQTT callbacks."""
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
with mock.patch('paho.mqtt.client.Client') as client:
client().connect = lambda *args: 0
assert setup_component(self.hass, mqtt.DOMAIN, {
mqtt.DOMAIN: {
mqtt.CONF_BROKER: 'mock-broker',
}
})
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
self.hass.stop()
def test_receiving_mqtt_message_fires_hass_event(self):
"""Test if receiving triggers an event."""
calls = []
@callback
def record(topic, payload, qos):
"""Helper to record calls."""
data = {
'topic': topic,
'payload': payload,
'qos': qos,
}
calls.append(data)
async_dispatcher_connect(
self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record)
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
self.hass.data['mqtt']._mqtt_on_message(
None, {'hass': self.hass}, message)
self.hass.block_till_done()
self.assertEqual(1, len(calls))
last_event = calls[0]
self.assertEqual(bytearray('Hello World!', 'utf-8'),
last_event['payload'])
self.assertEqual(message.topic, last_event['topic'])
self.assertEqual(message.qos, last_event['qos'])
def test_mqtt_failed_connection_results_in_disconnect(self): def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect.""" """Test if connection failure leads to disconnect."""
for result_code in range(1, 6): for result_code in range(1, 6):
@ -388,16 +353,11 @@ class TestMQTTCallbacks(unittest.TestCase):
@mock.patch('homeassistant.components.mqtt.time.sleep') @mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries.""" """Test the re-connect tries."""
self.hass.data['mqtt'].subscribed_topics = { self.hass.data['mqtt'].subscriptions = [
'test/topic': 1, mqtt.Subscription('test/progress', None, 0),
} mqtt.Subscription('test/progress', None, 1),
self.hass.data['mqtt'].wanted_topics = { mqtt.Subscription('test/topic', None, 2),
'test/progress': 0, ]
'test/topic': 2,
}
self.hass.data['mqtt'].progress = {
1: 'test/progress'
}
self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0] self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0]
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1) self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1)
self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called) self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called)
@ -406,15 +366,77 @@ class TestMQTTCallbacks(unittest.TestCase):
self.assertEqual([1, 2, 4], self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls]) [call[1][0] for call in mock_sleep.mock_calls])
self.assertEqual({'test/topic': 2, 'test/progress': 0}, def test_retained_message_on_subscribe_received(self):
self.hass.data['mqtt'].wanted_topics) """Test every subscriber receives retained message on subscribe."""
self.assertEqual({}, self.hass.data['mqtt'].subscribed_topics) def side_effect(*args):
self.assertEqual({}, self.hass.data['mqtt'].progress) async_fire_mqtt_message(self.hass, 'test/state', 'online')
return 0, 0
def test_invalid_mqtt_topics(self): self.hass.data['mqtt']._mqttc.subscribe.side_effect = side_effect
"""Test invalid topics."""
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic') calls_a = mock.MagicMock()
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one') mqtt.subscribe(self.hass, 'test/state', calls_a)
self.hass.block_till_done()
self.assertTrue(calls_a.called)
calls_b = mock.MagicMock()
mqtt.subscribe(self.hass, 'test/state', calls_b)
self.hass.block_till_done()
self.assertTrue(calls_b.called)
def test_not_calling_unsubscribe_with_active_subscribers(self):
"""Test not calling unsubscribe() when other subscribers are active."""
unsub = mqtt.subscribe(self.hass, 'test/state', None)
mqtt.subscribe(self.hass, 'test/state', None)
self.hass.block_till_done()
self.assertTrue(self.hass.data['mqtt']._mqttc.subscribe.called)
unsub()
self.hass.block_till_done()
self.assertFalse(self.hass.data['mqtt']._mqttc.unsubscribe.called)
def test_restore_subscriptions_on_reconnect(self):
"""Test subscriptions are restored on reconnect."""
mqtt.subscribe(self.hass, 'test/state', None)
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.call_count, 1)
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.hass.data['mqtt']._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.call_count, 2)
def test_restore_all_active_subscriptions_on_reconnect(self):
"""Test active subscriptions are restored correctly on reconnect."""
self.hass.data['mqtt']._mqttc.subscribe.side_effect = (
(0, 1), (0, 2), (0, 3), (0, 4)
)
unsub = mqtt.subscribe(self.hass, 'test/state', None, qos=2)
mqtt.subscribe(self.hass, 'test/state', None)
mqtt.subscribe(self.hass, 'test/state', None, qos=1)
self.hass.block_till_done()
expected = [
mock.call('test/state', 2),
mock.call('test/state', 0),
mock.call('test/state', 1)
]
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.mock_calls,
expected)
unsub()
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.unsubscribe.call_count,
0)
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.hass.data['mqtt']._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done()
expected.append(mock.call('test/state', 1))
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.mock_calls,
expected)
@asyncio.coroutine @asyncio.coroutine
@ -426,7 +448,7 @@ def test_setup_embedded_starts_with_no_config(hass):
return_value=mock_coro( return_value=mock_coro(
return_value=(True, client_config)) return_value=(True, client_config))
) as _start: ) as _start:
yield from mock_mqtt_client(hass, {}) yield from async_mock_mqtt_client(hass, {})
assert _start.call_count == 1 assert _start.call_count == 1
@ -440,7 +462,7 @@ def test_setup_embedded_with_embedded(hass):
return_value=(True, client_config)) return_value=(True, client_config))
) as _start: ) as _start:
_start.return_value = mock_coro(return_value=(True, client_config)) _start.return_value = mock_coro(return_value=(True, client_config))
yield from mock_mqtt_client(hass, {'embedded': None}) yield from async_mock_mqtt_client(hass, {'embedded': None})
assert _start.call_count == 1 assert _start.call_count == 1
@ -544,13 +566,13 @@ def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
@asyncio.coroutine @asyncio.coroutine
def test_birth_message(hass): def test_birth_message(hass):
"""Test sending birth message.""" """Test sending birth message."""
mqtt_client = yield from mock_mqtt_client(hass, { mqtt_client = yield from async_mock_mqtt_client(hass, {
mqtt.CONF_BROKER: 'mock-broker', mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth', mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'} mqtt.ATTR_PAYLOAD: 'birth'}
}) })
calls = [] calls = []
mqtt_client.publish = lambda *args: calls.append(args) mqtt_client.publish.side_effect = lambda *args: calls.append(args)
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0) hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done() yield from hass.async_block_till_done()
assert calls[-1] == ('birth', 'birth', 0, False) assert calls[-1] == ('birth', 'birth', 0, False)
@ -559,30 +581,26 @@ def test_birth_message(hass):
@asyncio.coroutine @asyncio.coroutine
def test_mqtt_subscribes_topics_on_connect(hass): def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect.""" """Test subscription to topic on connect."""
mqtt_client = yield from mock_mqtt_client(hass) mqtt_client = yield from async_mock_mqtt_client(hass)
subscribed_topics = OrderedDict() hass.data['mqtt'].subscriptions = [
subscribed_topics['topic/test'] = 1 mqtt.Subscription('topic/test', None),
subscribed_topics['home/sensor'] = 2 mqtt.Subscription('home/sensor', None, 2),
mqtt.Subscription('still/pending', None),
wanted_topics = subscribed_topics.copy() mqtt.Subscription('still/pending', None, 1),
wanted_topics['still/pending'] = 0 ]
hass.data['mqtt'].wanted_topics = wanted_topics
hass.data['mqtt'].subscribed_topics = subscribed_topics
hass.data['mqtt'].progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid)
mqtt_client.subscribe.side_effect = ((0, 2), (0, 3))
hass.add_job = mock.MagicMock() hass.add_job = mock.MagicMock()
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0) hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done() yield from hass.async_block_till_done()
assert not mqtt_client.disconnect.called assert mqtt_client.disconnect.call_count == 0
expected = [(topic, qos) for topic, qos in wanted_topics.items()] expected = {
'topic/test': 0,
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected 'home/sensor': 2,
assert hass.data['mqtt'].progress == {} 'still/pending': 1
}
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
assert calls == expected

View file

@ -70,16 +70,17 @@ class TestSwitchMQTT(unittest.TestCase):
switch.turn_on(self.hass, 'switch.test') switch.turn_on(self.hass, 'switch.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer on', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'beer on', 2, False)
self.mock_publish.async_publish.reset_mock()
state = self.hass.states.get('switch.test') state = self.hass.states.get('switch.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
switch.turn_off(self.hass, 'switch.test') switch.turn_off(self.hass, 'switch.test')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer off', 2, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'command-topic', 'beer off', 2, False)
state = self.hass.states.get('switch.test') state = self.hass.states.get('switch.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)

View file

@ -71,52 +71,56 @@ class TestVacuumMQTT(unittest.TestCase):
vacuum.turn_on(self.hass, 'vacuum.mqtttest') vacuum.turn_on(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'turn_on', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'turn_on', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.turn_off(self.hass, 'vacuum.mqtttest') vacuum.turn_off(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'turn_off', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'turn_off', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.stop(self.hass, 'vacuum.mqtttest') vacuum.stop(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'stop', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'stop', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.clean_spot(self.hass, 'vacuum.mqtttest') vacuum.clean_spot(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'clean_spot', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'clean_spot', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.locate(self.hass, 'vacuum.mqtttest') vacuum.locate(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'locate', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'locate', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.start_pause(self.hass, 'vacuum.mqtttest') vacuum.start_pause(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'start_pause', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'start_pause', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.return_to_base(self.hass, 'vacuum.mqtttest') vacuum.return_to_base(self.hass, 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('vacuum/command', 'return_to_base', 0, False), self.mock_publish.async_publish.assert_called_once_with(
self.mock_publish.mock_calls[-2][1]) 'vacuum/command', 'return_to_base', 0, False)
self.mock_publish.async_publish.reset_mock()
vacuum.set_fan_speed(self.hass, 'high', 'vacuum.mqtttest') vacuum.set_fan_speed(self.hass, 'high', 'vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual( self.mock_publish.async_publish.assert_called_once_with(
('vacuum/set_fan_speed', 'high', 0, False), 'vacuum/set_fan_speed', 'high', 0, False)
self.mock_publish.mock_calls[-2][1] self.mock_publish.async_publish.reset_mock()
)
vacuum.send_command(self.hass, '44 FE 93', entity_id='vacuum.mqtttest') vacuum.send_command(self.hass, '44 FE 93', entity_id='vacuum.mqtttest')
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual( self.mock_publish.async_publish.assert_called_once_with(
('vacuum/send_command', '44 FE 93', 0, False), 'vacuum/send_command', '44 FE 93', 0, False)
self.mock_publish.mock_calls[-2][1]
)
def test_status(self): def test_status(self):
"""Test status updates from the vacuum.""" """Test status updates from the vacuum."""

View file

@ -8,11 +8,11 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
import requests_mock as _requests_mock import requests_mock as _requests_mock
from homeassistant import util, setup from homeassistant import util
from homeassistant.util import location from homeassistant.util import location
from homeassistant.components import mqtt
from tests.common import async_test_home_assistant, mock_coro, INSTANCES from tests.common import async_test_home_assistant, INSTANCES, \
async_mock_mqtt_component
from tests.test_util.aiohttp import mock_aiohttp_client from tests.test_util.aiohttp import mock_aiohttp_client
from tests.mock.zwave import MockNetwork, MockOption from tests.mock.zwave import MockNetwork, MockOption
@ -85,17 +85,9 @@ def aioclient_mock():
@pytest.fixture @pytest.fixture
def mqtt_mock(loop, hass): def mqtt_mock(loop, hass):
"""Fixture to mock MQTT.""" """Fixture to mock MQTT."""
with patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: client = loop.run_until_complete(async_mock_mqtt_component(hass))
mock_mqtt().async_connect.return_value = mock_coro(True) client.reset_mock()
assert loop.run_until_complete(setup.async_setup_component( return client
hass, mqtt.DOMAIN, {
mqtt.DOMAIN: {
mqtt.CONF_BROKER: 'mock-broker',
}
}))
client = mock_mqtt()
client.reset_mock()
return client
@pytest.fixture @pytest.fixture