MQTT convert to async (#6064)

* Migrate mqtt to async

* address paulus comment / convert it complet async

* adress paulus comment / remove future

* Automation triggers should be async

* Fix MQTT async calls

* Show that event helpers are callbacks

* Fix tests

* Lint
This commit is contained in:
Pascal Vizeli 2017-02-18 23:17:18 +01:00 committed by Paulus Schoutsen
parent fa2c1dafdf
commit e1cbd6b4c0
25 changed files with 356 additions and 231 deletions

View file

@ -412,7 +412,7 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
if platform is None: if platform is None:
return None return None
remove = platform.async_trigger(hass, conf, action) remove = yield from platform.async_trigger(hass, conf, action)
if not remove: if not remove:
_LOGGER.error("Error setting up trigger %s", name) _LOGGER.error("Error setting up trigger %s", name)

View file

@ -4,6 +4,7 @@ Offer event listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#event-trigger at https://home-assistant.io/components/automation/#event-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for events based on configuration.""" """Listen for events based on configuration."""
event_type = config.get(CONF_EVENT_TYPE) event_type = config.get(CONF_EVENT_TYPE)

View file

@ -4,6 +4,7 @@ Trigger an automation when a LiteJet switch is released.
For more details about this platform, please refer to the documentation at For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/automation.litejet/ https://home-assistant.io/components/automation.litejet/
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -32,6 +33,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for events based on configuration.""" """Listen for events based on configuration."""
number = config.get(CONF_NUMBER) number = config.get(CONF_NUMBER)

View file

@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger at https://home-assistant.io/components/automation/#mqtt-trigger
""" """
import asyncio
import json import json
import voluptuous as vol import voluptuous as vol
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
topic = config.get(CONF_TOPIC) topic = config.get(CONF_TOPIC)
@ -49,4 +51,6 @@ def async_trigger(hass, config, action):
'trigger': data 'trigger': data
}) })
return mqtt.async_subscribe(hass, topic, mqtt_automation_listener) remove = yield from mqtt.async_subscribe(
hass, topic, mqtt_automation_listener)
return remove

View file

@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger at https://home-assistant.io/components/automation/#numeric-state-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)

View file

@ -4,6 +4,7 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger at https://home-assistant.io/components/automation/#state-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
@ -34,6 +35,7 @@ TRIGGER_SCHEMA = vol.All(
) )
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
@ -97,6 +99,7 @@ def async_trigger(hass, config, action):
unsub = async_track_state_change( unsub = async_track_state_change(
hass, entity_id, state_automation_listener, from_state, to_state) hass, entity_id, state_automation_listener, from_state, to_state)
@callback
def async_remove(): def async_remove():
"""Remove state listeners async.""" """Remove state listeners async."""
unsub() unsub()

View file

@ -4,6 +4,7 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger at https://home-assistant.io/components/automation/#sun-trigger
""" """
import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for events based on configuration.""" """Listen for events based on configuration."""
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)

View file

@ -4,6 +4,7 @@ Offer template automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#template-trigger at https://home-assistant.io/components/automation/#template-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -22,6 +23,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
value_template = config.get(CONF_VALUE_TEMPLATE) value_template = config.get(CONF_VALUE_TEMPLATE)

View file

@ -4,6 +4,7 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger at https://home-assistant.io/components/automation/#time-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@ -29,6 +30,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
CONF_SECONDS, CONF_AFTER)) CONF_SECONDS, CONF_AFTER))
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
if CONF_AFTER in config: if CONF_AFTER in config:

View file

@ -4,6 +4,7 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger at https://home-assistant.io/components/automation/#zone-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
}) })
@asyncio.coroutine
def async_trigger(hass, config, action): def async_trigger(hass, config, action):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)

View file

@ -4,6 +4,7 @@ Support for MQTT message handling.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/ https://home-assistant.io/components/mqtt/
""" """
import asyncio
import logging import logging
import os import os
import socket import socket
@ -12,11 +13,12 @@ import time
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.bootstrap import prepare_setup_platform from homeassistant.bootstrap import async_prepare_setup_platform
from homeassistant.config import load_yaml_config_file from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, config_validation as cv from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.event import threaded_listener_factory from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
from homeassistant.const import ( from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE,
CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD) CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD)
@ -26,7 +28,7 @@ _LOGGER = logging.getLogger(__name__)
DOMAIN = 'mqtt' DOMAIN = 'mqtt'
MQTT_CLIENT = None DATA_MQTT = 'mqtt'
SERVICE_PUBLISH = 'publish' SERVICE_PUBLISH = 'publish'
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received' EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
@ -183,11 +185,11 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
hass.services.call(DOMAIN, SERVICE_PUBLISH, data) hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
@callback @asyncio.coroutine
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic.""" """Subscribe to an MQTT topic."""
@callback @callback
def mqtt_topic_subscriber(event): def async_mqtt_topic_subscriber(event):
"""Match subscribed MQTT topic.""" """Match subscribed MQTT topic."""
if not _match_topic(topic, event.data[ATTR_TOPIC]): if not _match_topic(topic, event.data[ATTR_TOPIC]):
return return
@ -195,61 +197,82 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC], hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, async_remove = hass.bus.async_listen(
mqtt_topic_subscriber) EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
# Future: track subscriber count and unsubscribe in remove
MQTT_CLIENT.subscribe(topic, qos)
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
return async_remove return async_remove
# pylint: disable=invalid-name def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
subscribe = threaded_listener_factory(async_subscribe) """Subscribe to an MQTT topic."""
async_remove = run_coroutine_threadsafe(
async_subscribe(hass, topic, msg_callback, qos),
hass.loop
).result()
def remove():
"""Remove listener convert."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove
def _setup_server(hass, config): @asyncio.coroutine
"""Try to start embedded MQTT broker.""" def _async_setup_server(hass, config):
"""Try to start embedded MQTT broker.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {}) conf = config.get(DOMAIN, {})
# Only setup if embedded config passed in or no broker specified # Only setup if embedded config passed in or no broker specified
if CONF_EMBEDDED not in conf and CONF_BROKER in conf: if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
return None return None
server = prepare_setup_platform(hass, config, DOMAIN, 'server') server = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'server')
if server is None: if server is None:
_LOGGER.error("Unable to load embedded server") _LOGGER.error("Unable to load embedded server")
return None return None
success, broker_config = server.start(hass, conf.get(CONF_EMBEDDED)) success, broker_config = \
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
return success and broker_config return success and broker_config
def _setup_discovery(hass, config): @asyncio.coroutine
"""Try to start the discovery of MQTT devices.""" def _async_setup_discovery(hass, config):
"""Try to start the discovery of MQTT devices.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {}) conf = config.get(DOMAIN, {})
discovery = prepare_setup_platform(hass, config, DOMAIN, 'discovery') discovery = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'discovery')
if discovery is None: if discovery is None:
_LOGGER.error("Unable to load MQTT discovery") _LOGGER.error("Unable to load MQTT discovery")
return None return None
success = discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config) success = yield from discovery.async_start(
hass, conf[CONF_DISCOVERY_PREFIX], config)
return success return success
def setup(hass, config): @asyncio.coroutine
def async_setup(hass, config):
"""Start the MQTT protocol service.""" """Start the MQTT protocol service."""
conf = config.get(DOMAIN, {}) conf = config.get(DOMAIN, {})
client_id = conf.get(CONF_CLIENT_ID) client_id = conf.get(CONF_CLIENT_ID)
keepalive = conf.get(CONF_KEEPALIVE) keepalive = conf.get(CONF_KEEPALIVE)
broker_config = _setup_server(hass, config) broker_config = yield from _async_setup_server(hass, config)
if CONF_BROKER in conf: if CONF_BROKER in conf:
broker = conf[CONF_BROKER] broker = conf[CONF_BROKER]
@ -283,27 +306,31 @@ def setup(hass, config):
will_message = conf.get(CONF_WILL_MESSAGE) will_message = conf.get(CONF_WILL_MESSAGE)
birth_message = conf.get(CONF_BIRTH_MESSAGE) birth_message = conf.get(CONF_BIRTH_MESSAGE)
global MQTT_CLIENT
try: try:
MQTT_CLIENT = MQTT(hass, broker, port, client_id, keepalive, hass.data[DATA_MQTT] = MQTT(
username, password, certificate, client_key, hass, broker, port, client_id, keepalive, username, password,
client_cert, tls_insecure, protocol, will_message, certificate, client_key, client_cert, tls_insecure, protocol,
birth_message) will_message, birth_message)
except socket.error: except socket.error:
_LOGGER.exception("Can't connect to the broker. " _LOGGER.exception("Can't connect to the broker. "
"Please check your settings and the broker itself") "Please check your settings and the broker itself")
return False return False
def stop_mqtt(event): @asyncio.coroutine
def async_stop_mqtt(event):
"""Stop MQTT component.""" """Stop MQTT component."""
MQTT_CLIENT.stop() yield from hass.data[DATA_MQTT].async_stop()
def start_mqtt(event): @asyncio.coroutine
def async_start_mqtt(event):
"""Launch MQTT component when Home Assistant starts up.""" """Launch MQTT component when Home Assistant starts up."""
MQTT_CLIENT.start() yield from hass.data[DATA_MQTT].async_start()
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_mqtt) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
def publish_service(call): hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, async_start_mqtt)
@asyncio.coroutine
def async_publish_service(call):
"""Handle MQTT publish service calls.""" """Handle MQTT publish service calls."""
msg_topic = call.data[ATTR_TOPIC] msg_topic = call.data[ATTR_TOPIC]
payload = call.data.get(ATTR_PAYLOAD) payload = call.data.get(ATTR_PAYLOAD)
@ -312,26 +339,28 @@ def setup(hass, config):
retain = call.data[ATTR_RETAIN] retain = call.data[ATTR_RETAIN]
try: try:
if payload_template is not None: if payload_template is not None:
payload = template.Template(payload_template, hass).render() payload = \
template.Template(payload_template, hass).async_render()
except template.jinja2.TemplateError as exc: except template.jinja2.TemplateError as exc:
_LOGGER.error( _LOGGER.error(
"Unable to publish to '%s': rendering payload template of " "Unable to publish to '%s': rendering payload template of "
"'%s' failed because %s", "'%s' failed because %s",
msg_topic, payload_template, exc) msg_topic, payload_template, exc)
return return
MQTT_CLIENT.publish(msg_topic, payload, qos, retain)
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_mqtt) yield from hass.data[DATA_MQTT].async_publish(
msg_topic, payload, qos, retain)
descriptions = load_yaml_config_file( descriptions = yield from hass.loop.run_in_executor(
os.path.join(os.path.dirname(__file__), 'services.yaml')) None, load_yaml_config_file, os.path.join(
os.path.dirname(__file__), 'services.yaml'))
hass.services.register(DOMAIN, SERVICE_PUBLISH, publish_service, hass.services.async_register(
descriptions.get(SERVICE_PUBLISH), DOMAIN, SERVICE_PUBLISH, async_publish_service,
schema=MQTT_PUBLISH_SCHEMA) descriptions.get(SERVICE_PUBLISH), schema=MQTT_PUBLISH_SCHEMA)
if conf.get(CONF_DISCOVERY): if conf.get(CONF_DISCOVERY):
_setup_discovery(hass, config) yield from _async_setup_discovery(hass, config)
return True return True
@ -349,6 +378,7 @@ class MQTT(object):
self.topics = {} self.topics = {}
self.progress = {} self.progress = {}
self.birth_message = birth_message self.birth_message = birth_message
self._mqttc = None
if protocol == PROTOCOL_31: if protocol == PROTOCOL_31:
proto = mqtt.MQTTv31 proto = mqtt.MQTTv31
@ -364,8 +394,8 @@ class MQTT(object):
self._mqttc.username_pw_set(username, password) self._mqttc.username_pw_set(username, password)
if certificate is not None: if certificate is not None:
self._mqttc.tls_set(certificate, certfile=client_cert, self._mqttc.tls_set(
keyfile=client_key) certificate, certfile=client_cert, keyfile=client_key)
if tls_insecure is not None: if tls_insecure is not None:
self._mqttc.tls_insecure_set(tls_insecure) self._mqttc.tls_insecure_set(tls_insecure)
@ -375,40 +405,69 @@ class MQTT(object):
self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message self._mqttc.on_message = self._mqtt_on_message
if will_message: if will_message:
self._mqttc.will_set(will_message.get(ATTR_TOPIC), self._mqttc.will_set(will_message.get(ATTR_TOPIC),
will_message.get(ATTR_PAYLOAD), will_message.get(ATTR_PAYLOAD),
will_message.get(ATTR_QOS), will_message.get(ATTR_QOS),
will_message.get(ATTR_RETAIN)) will_message.get(ATTR_RETAIN))
self._mqttc.connect(broker, port, keepalive)
def publish(self, topic, payload, qos, retain): self._mqttc.connect_async(broker, port, keepalive)
"""Publish a MQTT message."""
self._mqttc.publish(topic, payload, qos, retain)
def start(self): def async_publish(self, topic, payload, qos, retain):
"""Run the MQTT client.""" """Publish a MQTT message.
self._mqttc.loop_start()
def stop(self): This method must be run in the event loop and returns a coroutine.
"""Stop the MQTT client.""" """
self._mqttc.disconnect() return self.hass.loop.run_in_executor(
self._mqttc.loop_stop() None, self._mqttc.publish, topic, payload, qos, retain)
def subscribe(self, topic, qos): def async_start(self):
"""Subscribe to a topic.""" """Run the MQTT client.
assert isinstance(topic, str)
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self._mqttc.loop_start)
def async_stop(self):
"""Stop the MQTT client.
This method must be run in the event loop and returns a coroutine.
"""
def stop(self):
"""Stop the MQTT client."""
self._mqttc.disconnect()
self._mqttc.loop_stop()
return self.hass.loop.run_in_executor(None, stop)
@asyncio.coroutine
def async_subscribe(self, topic, qos):
"""Subscribe to a topic.
This method is a coroutine.
"""
if not isinstance(topic, str):
raise HomeAssistantError("topic need to be a string!")
if topic in self.topics: if topic in self.topics:
return return
result, mid = self._mqttc.subscribe(topic, qos) result, mid = yield from self.hass.loop.run_in_executor(
None, self._mqttc.subscribe, topic, qos)
_raise_on_error(result) _raise_on_error(result)
self.progress[mid] = topic self.progress[mid] = topic
self.topics[topic] = None self.topics[topic] = None
def unsubscribe(self, topic): @asyncio.coroutine
"""Unsubscribe from topic.""" def async_unsubscribe(self, topic):
result, mid = self._mqttc.unsubscribe(topic) """Unsubscribe from topic.
This method is a coroutine.
"""
result, mid = yield from self.hass.loop.run_in_executor(
None, self._mqttc.unsubscribe, topic)
_raise_on_error(result) _raise_on_error(result)
self.progress[mid] = topic self.progress[mid] = topic
@ -437,12 +496,14 @@ class MQTT(object):
for topic, qos in old_topics.items(): for topic, qos in old_topics.items():
# qos is None if we were in process of subscribing # qos is None if we were in process of subscribing
if qos is not None: if qos is not None:
self.subscribe(topic, qos) self.hass.add_job(self.async_subscribe, topic, qos)
if self.birth_message: if self.birth_message:
self.publish(self.birth_message.get(ATTR_TOPIC), self.hass.add_job(self.async_publish(
self.birth_message.get(ATTR_PAYLOAD), self.birth_message.get(ATTR_TOPIC),
self.birth_message.get(ATTR_QOS), self.birth_message.get(ATTR_PAYLOAD),
self.birth_message.get(ATTR_RETAIN)) self.birth_message.get(ATTR_QOS),
self.birth_message.get(ATTR_RETAIN)))
def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos): def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Subscribe successful callback.""" """Subscribe successful callback."""

View file

@ -9,7 +9,6 @@ import json
import logging import logging
import re import re
from homeassistant.core import callback
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from homeassistant.components.mqtt import DOMAIN from homeassistant.components.mqtt import DOMAIN
from homeassistant.helpers.discovery import async_load_platform from homeassistant.helpers.discovery import async_load_platform
@ -24,7 +23,7 @@ TOPIC_MATCHER = re.compile(
SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor'] SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor']
@callback @asyncio.coroutine
def async_start(hass, discovery_topic, hass_config): def async_start(hass, discovery_topic, hass_config):
"""Initialization of MQTT Discovery.""" """Initialization of MQTT Discovery."""
@asyncio.coroutine @asyncio.coroutine
@ -56,7 +55,7 @@ def async_start(hass, discovery_topic, hass_config):
yield from async_load_platform( yield from async_load_platform(
hass, component, DOMAIN, payload, hass_config) hass, component, DOMAIN, payload, hass_config)
mqtt.async_subscribe(hass, discovery_topic + '/#', yield from mqtt.async_subscribe(
async_device_message_received, 0) hass, discovery_topic + '/#', async_device_message_received, 0)
return True return True

View file

@ -4,15 +4,14 @@ Support for a local MQTT broker.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/#use-the-embedded-broker https://home-assistant.io/components/mqtt/#use-the-embedded-broker
""" """
import asyncio
import logging import logging
import tempfile import tempfile
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
REQUIREMENTS = ['hbmqtt==0.8'] REQUIREMENTS = ['hbmqtt==0.8']
DEPENDENCIES = ['http'] DEPENDENCIES = ['http']
@ -29,8 +28,12 @@ HBMQTT_CONFIG_SCHEMA = vol.Any(None, vol.Schema({
}, extra=vol.ALLOW_EXTRA)) }, extra=vol.ALLOW_EXTRA))
def start(hass, server_config): @asyncio.coroutine
"""Initialize MQTT Server.""" def async_start(hass, server_config):
"""Initialize MQTT Server.
This method is a coroutine.
"""
from hbmqtt.broker import Broker, BrokerException from hbmqtt.broker import Broker, BrokerException
try: try:
@ -42,19 +45,20 @@ def start(hass, server_config):
client_config = None client_config = None
broker = Broker(server_config, hass.loop) broker = Broker(server_config, hass.loop)
run_coroutine_threadsafe(broker.start(), hass.loop).result() yield from broker.start()
except BrokerException: except BrokerException:
logging.getLogger(__name__).exception('Error initializing MQTT server') logging.getLogger(__name__).exception('Error initializing MQTT server')
return False, None return False, None
finally: finally:
passwd.close() passwd.close()
@callback @asyncio.coroutine
def shutdown_mqtt_server(event): def async_shutdown_mqtt_server(event):
"""Shut down the MQTT server.""" """Shut down the MQTT server."""
hass.async_add_job(broker.shutdown()) yield from broker.shutdown()
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, shutdown_mqtt_server) hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, async_shutdown_mqtt_server)
return True, client_config return True, client_config

View file

@ -34,6 +34,7 @@ def threaded_listener_factory(async_factory):
return factory return factory
@callback
def async_track_state_change(hass, entity_ids, action, from_state=None, def async_track_state_change(hass, entity_ids, action, from_state=None,
to_state=None): to_state=None):
"""Track specific state changes. """Track specific state changes.
@ -84,6 +85,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
track_state_change = threaded_listener_factory(async_track_state_change) track_state_change = threaded_listener_factory(async_track_state_change)
@callback
def async_track_template(hass, template, action, variables=None): def async_track_template(hass, template, action, variables=None):
"""Add a listener that track state changes with template condition.""" """Add a listener that track state changes with template condition."""
from . import condition from . import condition
@ -111,6 +113,7 @@ def async_track_template(hass, template, action, variables=None):
track_template = threaded_listener_factory(async_track_template) track_template = threaded_listener_factory(async_track_template)
@callback
def async_track_point_in_time(hass, action, point_in_time): def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in time.""" """Add a listener that fires once after a specific point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time) utc_point_in_time = dt_util.as_utc(point_in_time)
@ -127,6 +130,7 @@ def async_track_point_in_time(hass, action, point_in_time):
track_point_in_time = threaded_listener_factory(async_track_point_in_time) track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@callback
def async_track_point_in_utc_time(hass, action, point_in_time): def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time.""" """Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC # Ensure point_in_time is UTC
@ -160,6 +164,7 @@ track_point_in_utc_time = threaded_listener_factory(
async_track_point_in_utc_time) async_track_point_in_utc_time)
@callback
def async_track_time_interval(hass, action, interval): def async_track_time_interval(hass, action, interval):
"""Add a listener that fires repetitively at every timedelta interval.""" """Add a listener that fires repetitively at every timedelta interval."""
remove = None remove = None
@ -189,6 +194,7 @@ def async_track_time_interval(hass, action, interval):
track_time_interval = threaded_listener_factory(async_track_time_interval) track_time_interval = threaded_listener_factory(async_track_time_interval)
@callback
def async_track_sunrise(hass, action, offset=None): def async_track_sunrise(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunrise daily.""" """Add a listener that will fire a specified offset from sunrise daily."""
from homeassistant.components import sun from homeassistant.components import sun
@ -225,6 +231,7 @@ def async_track_sunrise(hass, action, offset=None):
track_sunrise = threaded_listener_factory(async_track_sunrise) track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback
def async_track_sunset(hass, action, offset=None): def async_track_sunset(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunset daily.""" """Add a listener that will fire a specified offset from sunset daily."""
from homeassistant.components import sun from homeassistant.components import sun
@ -261,6 +268,7 @@ def async_track_sunset(hass, action, offset=None):
track_sunset = threaded_listener_factory(async_track_sunset) track_sunset = threaded_listener_factory(async_track_sunset)
@callback
def async_track_utc_time_change(hass, action, year=None, month=None, day=None, def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None, hour=None, minute=None, second=None,
local=False): local=False):
@ -305,6 +313,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change) track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
@callback
def async_track_time_change(hass, action, year=None, month=None, day=None, def async_track_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None): hour=None, minute=None, second=None):
"""Add a listener that will fire if UTC time matches a pattern.""" """Add a listener that will fire if UTC time matches a pattern."""

View file

@ -86,7 +86,13 @@ def async_test_home_assistant(loop):
loop._thread_ident = threading.get_ident() loop._thread_ident = threading.get_ident()
hass = ha.HomeAssistant(loop) hass = ha.HomeAssistant(loop)
hass.async_track_tasks()
def async_add_job(target, *args):
if isinstance(target, MagicMock):
return
hass._async_add_job_tracking(target, *args)
hass.async_add_job = async_add_job
hass.config.location_name = 'test home' hass.config.location_name = 'test home'
hass.config.config_dir = get_test_config_dir() hass.config.config_dir = get_test_config_dir()

View file

@ -111,7 +111,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_home(self.hass) alarm_control_panel.alarm_arm_home(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_HOME', 0, False), self.assertEqual(('alarm/command', 'ARM_HOME', 0, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
def test_arm_home_not_publishes_mqtt_with_invalid_code(self): def test_arm_home_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""
@ -146,7 +146,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_away(self.hass) alarm_control_panel.alarm_arm_away(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False), self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
def test_arm_away_not_publishes_mqtt_with_invalid_code(self): def test_arm_away_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""
@ -181,7 +181,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_disarm(self.hass) alarm_control_panel.alarm_disarm(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('alarm/command', 'DISARM', 0, False), self.assertEqual(('alarm/command', 'DISARM', 0, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
def test_disarm_not_publishes_mqtt_with_invalid_code(self): def test_disarm_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code.""" """Test not publishing of MQTT messages with invalid code."""

View file

@ -118,7 +118,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 0, False), self.assertEqual(('command-topic', 'OPEN', 0, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_OPEN, state.state) self.assertEqual(STATE_OPEN, state.state)
@ -126,7 +126,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 0, False), self.assertEqual(('command-topic', 'CLOSE', 0, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_CLOSED, state.state) self.assertEqual(STATE_CLOSED, state.state)
@ -150,7 +150,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 2, False), self.assertEqual(('command-topic', 'OPEN', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)
@ -174,7 +174,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 2, False), self.assertEqual(('command-topic', 'CLOSE', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)
@ -198,7 +198,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'STOP', 2, False), self.assertEqual(('command-topic', 'STOP', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test') state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state) self.assertEqual(STATE_UNKNOWN, state.state)

View file

@ -74,6 +74,7 @@ light:
""" """
import unittest import unittest
from unittest import mock
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE
@ -328,7 +329,7 @@ class TestLightMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'on', 2, False), self.assertEqual(('test_light_rgb/set', 'on', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
@ -336,27 +337,20 @@ class TestLightMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'off', 2, False), self.assertEqual(('test_light_rgb/set', 'off', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)
self.mock_publish.reset_mock()
light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75], light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75],
brightness=50) brightness=50)
self.hass.block_till_done() self.hass.block_till_done()
# Calls are threaded so we need to reorder them self.mock_publish().async_publish.assert_has_calls([
bright_call, rgb_call, state_call = \ mock.call('test_light_rgb/set', 'on', 2, False),
sorted((call[1] for call in self.mock_publish.mock_calls[-3:]), mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False),
key=lambda call: call[0]) mock.call('test_light_rgb/brightness/set', 50, 2, False),
], any_order=True)
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
state_call)
self.assertEqual(('test_light_rgb/rgb/set', '75,75,75', 2, False),
rgb_call)
self.assertEqual(('test_light_rgb/brightness/set', 50, 2, False),
bright_call)
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)

View file

@ -172,7 +172,7 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False), self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
@ -180,7 +180,7 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False), self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)
@ -189,11 +189,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message # Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(50, message_json["brightness"]) self.assertEqual(50, message_json["brightness"])
self.assertEqual(75, message_json["color"]["r"]) self.assertEqual(75, message_json["color"]["r"])
self.assertEqual(75, message_json["color"]["g"]) self.assertEqual(75, message_json["color"]["g"])
@ -228,11 +228,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message # Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(5, message_json["flash"]) self.assertEqual(5, message_json["flash"])
self.assertEqual("ON", message_json["state"]) self.assertEqual("ON", message_json["state"])
@ -240,11 +240,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message # Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(15, message_json["flash"]) self.assertEqual(15, message_json["flash"])
self.assertEqual("ON", message_json["state"]) self.assertEqual("ON", message_json["state"])
@ -268,11 +268,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message # Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(10, message_json["transition"]) self.assertEqual(10, message_json["transition"])
self.assertEqual("ON", message_json["state"]) self.assertEqual("ON", message_json["state"])
@ -281,11 +281,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message # Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(10, message_json["transition"]) self.assertEqual(10, message_json["transition"])
self.assertEqual("OFF", message_json["state"]) self.assertEqual("OFF", message_json["state"])

View file

@ -196,7 +196,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False), self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
@ -205,7 +205,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'off', 2, False), self.assertEqual(('test_light_rgb/set', 'off', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test') state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)
@ -215,12 +215,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload # check the payload
payload = self.mock_publish.mock_calls[-1][1][1] payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,50,75-75-75', payload) self.assertEqual('on,50,75-75-75', payload)
# check the state # check the state
@ -253,12 +253,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload # check the payload
payload = self.mock_publish.mock_calls[-1][1][1] payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,short', payload) self.assertEqual('on,short', payload)
# long flash # long flash
@ -266,12 +266,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload # check the payload
payload = self.mock_publish.mock_calls[-1][1][1] payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,long', payload) self.assertEqual('on,long', payload)
def test_transition(self): def test_transition(self):
@ -296,12 +296,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload # check the payload
payload = self.mock_publish.mock_calls[-1][1][1] payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,10', payload) self.assertEqual('on,10', payload)
# transition off # transition off
@ -309,12 +309,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual('test_light_rgb/set', self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0]) self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload # check the payload
payload = self.mock_publish.mock_calls[-1][1][1] payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('off,4', payload) self.assertEqual('off,4', payload)
def test_invalid_values(self): \ def test_invalid_values(self): \

View file

@ -73,7 +73,7 @@ class TestLockMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'LOCK', 2, False), self.assertEqual(('command-topic', 'LOCK', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('lock.test') state = self.hass.states.get('lock.test')
self.assertEqual(STATE_LOCKED, state.state) self.assertEqual(STATE_LOCKED, state.state)
@ -81,7 +81,7 @@ class TestLockMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'UNLOCK', 2, False), self.assertEqual(('command-topic', 'UNLOCK', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('lock.test') state = self.hass.states.get('lock.test')
self.assertEqual(STATE_UNLOCKED, state.state) self.assertEqual(STATE_UNLOCKED, state.state)

View file

@ -12,9 +12,10 @@ def test_subscribing_config_topic(hass, mqtt_mock):
"""Test setting up discovery.""" """Test setting up discovery."""
hass_config = {} hass_config = {}
discovery_topic = 'homeassistant' discovery_topic = 'homeassistant'
async_start(hass, discovery_topic, hass_config) yield from async_start(hass, discovery_topic, hass_config)
assert mqtt_mock.subscribe.called
call_args = mqtt_mock.subscribe.mock_calls[0][1] assert mqtt_mock.async_subscribe.called
call_args = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args[0] == discovery_topic + '/#' assert call_args[0] == discovery_topic + '/#'
assert call_args[1] == 0 assert call_args[1] == 0
@ -24,7 +25,7 @@ def test_subscribing_config_topic(hass, mqtt_mock):
def test_invalid_topic(mock_load_platform, hass, mqtt_mock): def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
"""Test sending in invalid JSON.""" """Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro() mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {}) yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config', async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config',
'{}') '{}')
@ -37,7 +38,7 @@ def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog): def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
"""Test sending in invalid JSON.""" """Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro() mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {}) yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
'not json') 'not json')
@ -51,7 +52,7 @@ def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog): def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
"""Test sending in invalid JSON.""" """Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro() mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {}) yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}') async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}')
yield from hass.async_block_till_done() yield from hass.async_block_till_done()
@ -62,7 +63,7 @@ def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
@asyncio.coroutine @asyncio.coroutine
def test_correct_config_discovery(hass, mqtt_mock, caplog): def test_correct_config_discovery(hass, mqtt_mock, caplog):
"""Test sending in invalid JSON.""" """Test sending in invalid JSON."""
async_start(hass, 'homeassistant', {}) yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
'{ "name": "Beer" }') '{ "name": "Beer" }')

View file

@ -1,5 +1,6 @@
"""The tests for the MQTT component.""" """The tests for the MQTT component."""
from collections import namedtuple import asyncio
from collections import namedtuple, OrderedDict
import unittest import unittest
from unittest import mock from unittest import mock
import socket import socket
@ -7,14 +8,29 @@ import socket
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component, async_setup_component
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from homeassistant.const import ( from homeassistant.const import (
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP) EVENT_HOMEASSISTANT_STOP)
from tests.common import ( from tests.common import (
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message) get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro)
@asyncio.coroutine
def mock_mqtt_client(hass, config=None):
"""Mock the MQTT paho client."""
if config is None:
config = {
mqtt.CONF_BROKER: 'mock-broker'
}
with mock.patch('paho.mqtt.client.Client') as mock_client:
yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: config
})
return mock_client()
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -40,7 +56,7 @@ class TestMQTT(unittest.TestCase):
""""Test if client start on HA launch.""" """"Test if client start on HA launch."""
self.hass.bus.fire(EVENT_HOMEASSISTANT_START) self.hass.bus.fire(EVENT_HOMEASSISTANT_START)
self.hass.block_till_done() self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.start.called) self.assertTrue(self.hass.data['mqtt'].async_start.called)
def test_client_stops_on_home_assistant_start(self): def test_client_stops_on_home_assistant_start(self):
"""Test if client stops on HA launch.""" """Test if client stops on HA launch."""
@ -48,7 +64,7 @@ class TestMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP) self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
self.hass.block_till_done() self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.stop.called) self.assertTrue(self.hass.data['mqtt'].async_stop.called)
@mock.patch('paho.mqtt.client.Client') @mock.patch('paho.mqtt.client.Client')
def test_setup_fails_if_no_connect_broker(self, _): def test_setup_fails_if_no_connect_broker(self, _):
@ -69,14 +85,17 @@ class TestMQTT(unittest.TestCase):
"""Test setting up embedded server with no config.""" """Test setting up embedded server with no config."""
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1') client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
with mock.patch('homeassistant.components.mqtt.server.start', with mock.patch('homeassistant.components.mqtt.server.async_start',
return_value=(True, client_config)) as _start: return_value=mock_coro(
return_value=(True, client_config))
) as _start:
self.hass.config.components = set() self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN, assert setup_component(self.hass, mqtt.DOMAIN,
{mqtt.DOMAIN: {}}) {mqtt.DOMAIN: {}})
assert _start.call_count == 1 assert _start.call_count == 1
# Test with `embedded: None` # Test with `embedded: None`
_start.return_value = mock_coro(return_value=(True, client_config))
self.hass.config.components = set() self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN, assert setup_component(self.hass, mqtt.DOMAIN,
{mqtt.DOMAIN: {'embedded': None}}) {mqtt.DOMAIN: {'embedded': None}})
@ -105,7 +124,7 @@ class TestMQTT(unittest.TestCase):
ATTR_SERVICE: mqtt.SERVICE_PUBLISH ATTR_SERVICE: mqtt.SERVICE_PUBLISH
}) })
self.hass.block_till_done() self.hass.block_till_done()
self.assertTrue(not mqtt.MQTT_CLIENT.publish.called) self.assertTrue(not self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_template_payload_renders_template(self): def test_service_call_with_template_payload_renders_template(self):
"""Test the service call with rendered template. """Test the service call with rendered template.
@ -114,8 +133,9 @@ class TestMQTT(unittest.TestCase):
""" """
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}") mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}")
self.hass.block_till_done() self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.publish.called) self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][1], "2") self.assertEqual(
self.hass.data['mqtt'].async_publish.call_args[0][1], "2")
def test_service_call_with_payload_doesnt_render_template(self): def test_service_call_with_payload_doesnt_render_template(self):
"""Test the service call with unrendered template. """Test the service call with unrendered template.
@ -129,7 +149,7 @@ class TestMQTT(unittest.TestCase):
mqtt.ATTR_PAYLOAD: payload, mqtt.ATTR_PAYLOAD: payload,
mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template
}, blocking=True) }, blocking=True)
self.assertFalse(mqtt.MQTT_CLIENT.publish.called) self.assertFalse(self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_ascii_qos_retain_flags(self): def test_service_call_with_ascii_qos_retain_flags(self):
"""Test the service call with args that can be misinterpreted. """Test the service call with args that can be misinterpreted.
@ -142,9 +162,10 @@ class TestMQTT(unittest.TestCase):
mqtt.ATTR_QOS: '2', mqtt.ATTR_QOS: '2',
mqtt.ATTR_RETAIN: 'no' mqtt.ATTR_RETAIN: 'no'
}, blocking=True) }, blocking=True)
self.assertTrue(mqtt.MQTT_CLIENT.publish.called) self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][2], 2) self.assertEqual(
self.assertFalse(mqtt.MQTT_CLIENT.publish.call_args[0][3]) self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
def test_subscribe_topic(self): def test_subscribe_topic(self):
"""Test the subscription of a topic.""" """Test the subscription of a topic."""
@ -231,15 +252,12 @@ class TestMQTTCallbacks(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
# mock_mqtt_component(self.hass)
with mock.patch('paho.mqtt.client.Client'): with mock.patch('paho.mqtt.client.Client'):
self.hass.config.components = set() self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN, { assert setup_component(self.hass, mqtt.DOMAIN, {
mqtt.DOMAIN: { mqtt.DOMAIN: {
mqtt.CONF_BROKER: 'mock-broker', mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'}
} }
}) })
@ -261,7 +279,8 @@ class TestMQTTCallbacks(unittest.TestCase):
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
mqtt.MQTT_CLIENT._mqtt_on_message(None, {'hass': self.hass}, message) self.hass.data['mqtt']._mqtt_on_message(
None, {'hass': self.hass}, message)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(calls)) self.assertEqual(1, len(calls))
@ -273,68 +292,36 @@ class TestMQTTCallbacks(unittest.TestCase):
def test_mqtt_failed_connection_results_in_disconnect(self): def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect.""" """Test if connection failure leads to disconnect."""
for result_code in range(1, 6): for result_code in range(1, 6):
mqtt.MQTT_CLIENT._mqttc = mock.MagicMock() self.hass.data['mqtt']._mqttc = mock.MagicMock()
mqtt.MQTT_CLIENT._mqtt_on_connect(None, {'topics': {}}, 0, self.hass.data['mqtt']._mqtt_on_connect(
result_code) None, {'topics': {}}, 0, result_code)
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.disconnect.called) self.assertTrue(self.hass.data['mqtt']._mqttc.disconnect.called)
def test_mqtt_subscribes_topics_on_connect(self):
"""Test subscription to topic on connect."""
from collections import OrderedDict
prev_topics = OrderedDict()
prev_topics['topic/test'] = 1,
prev_topics['home/sensor'] = 2,
prev_topics['still/pending'] = None
mqtt.MQTT_CLIENT.topics = prev_topics
mqtt.MQTT_CLIENT.progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid)
mqtt.MQTT_CLIENT._mqttc.subscribe.side_effect = ((0, 2), (0, 3))
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.disconnect.called)
expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
self.assertEqual(
expected,
[call[1] for call in mqtt.MQTT_CLIENT._mqttc.subscribe.mock_calls])
self.assertEqual({
1: 'still/pending',
2: 'topic/test',
3: 'home/sensor',
}, mqtt.MQTT_CLIENT.progress)
def test_mqtt_birth_message_on_connect(self): \
# pylint: disable=no-self-use
"""Test birth message on connect."""
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
mqtt.MQTT_CLIENT._mqttc.publish.assert_called_with('birth', 'birth', 0,
False)
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self): def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
"""Test the disconnect tries.""" """Test the disconnect tries."""
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 0) self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.reconnect.called) self.assertFalse(self.hass.data['mqtt']._mqttc.reconnect.called)
@mock.patch('homeassistant.components.mqtt.time.sleep') @mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries.""" """Test the re-connect tries."""
mqtt.MQTT_CLIENT.topics = { self.hass.data['mqtt'].topics = {
'test/topic': 1, 'test/topic': 1,
'test/progress': None 'test/progress': None
} }
mqtt.MQTT_CLIENT.progress = { self.hass.data['mqtt'].progress = {
1: 'test/progress' 1: 'test/progress'
} }
mqtt.MQTT_CLIENT._mqttc.reconnect.side_effect = [1, 1, 1, 0] self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0]
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 1) self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1)
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.reconnect.called) self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called)
self.assertEqual(4, len(mqtt.MQTT_CLIENT._mqttc.reconnect.mock_calls)) self.assertEqual(
4, len(self.hass.data['mqtt']._mqttc.reconnect.mock_calls))
self.assertEqual([1, 2, 4], self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls]) [call[1][0] for call in mock_sleep.mock_calls])
self.assertEqual({'test/topic': 1}, mqtt.MQTT_CLIENT.topics) self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics)
self.assertEqual({}, mqtt.MQTT_CLIENT.progress) self.assertEqual({}, self.hass.data['mqtt'].progress)
def test_invalid_mqtt_topics(self): def test_invalid_mqtt_topics(self):
"""Test invalid topics.""" """Test invalid topics."""
@ -356,7 +343,7 @@ class TestMQTTCallbacks(unittest.TestCase):
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(topic, 1, payload) message = MQTTMessage(topic, 1, payload)
with self.assertLogs(level='ERROR') as test_handle: with self.assertLogs(level='ERROR') as test_handle:
mqtt.MQTT_CLIENT._mqtt_on_message( self.hass.data['mqtt']._mqtt_on_message(
None, None,
{'hass': self.hass}, {'hass': self.hass},
message) message)
@ -365,3 +352,47 @@ class TestMQTTCallbacks(unittest.TestCase):
"ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode " "ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode "
"payload from MQTT topic: %s, Payload: " % topic, "payload from MQTT topic: %s, Payload: " % topic,
test_handle.output[0]) test_handle.output[0])
@asyncio.coroutine
def test_birth_message(hass):
"""Test sending birth message."""
mqtt_client = yield from mock_mqtt_client(hass, {
mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'}
})
calls = []
mqtt_client.publish = lambda *args: calls.append(args)
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert calls[-1] == ('birth', 'birth', 0, False)
@asyncio.coroutine
def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect."""
mqtt_client = yield from mock_mqtt_client(hass)
prev_topics = OrderedDict()
prev_topics['topic/test'] = 1,
prev_topics['home/sensor'] = 2,
prev_topics['still/pending'] = None
hass.data['mqtt'].topics = prev_topics
hass.data['mqtt'].progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid)
mqtt_client.subscribe.side_effect = ((0, 2), (0, 3))
hass.add_job = mock.MagicMock()
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert not mqtt_client.disconnect.called
expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant, mock_coro
class TestMQTT: class TestMQTT:
@ -21,9 +21,8 @@ class TestMQTT:
@patch('passlib.apps.custom_app_context', Mock(return_value='')) @patch('passlib.apps.custom_app_context', Mock(return_value=''))
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock())) @patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe',
Mock(return_value=MagicMock()))
@patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock())) @patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock()))
@patch('hbmqtt.broker.Broker.start', Mock(return_value=mock_coro()))
@patch('homeassistant.components.mqtt.MQTT') @patch('homeassistant.components.mqtt.MQTT')
def test_creating_config_with_http_pass(self, mock_mqtt): def test_creating_config_with_http_pass(self, mock_mqtt):
"""Test if the MQTT server gets started and subscribe/publish msg.""" """Test if the MQTT server gets started and subscribe/publish msg."""
@ -46,7 +45,7 @@ class TestMQTT:
assert mock_mqtt.mock_calls[0][1][6] is None assert mock_mqtt.mock_calls[0][1][6] is None
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock())) @patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe') @patch('hbmqtt.broker.Broker.start', return_value=mock_coro())
def test_broker_config_fails(self, mock_run): def test_broker_config_fails(self, mock_run):
"""Test if the MQTT component fails if server fails.""" """Test if the MQTT component fails if server fails."""
from hbmqtt.broker import BrokerException from hbmqtt.broker import BrokerException

View file

@ -72,7 +72,7 @@ class TestSensorMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer on', 2, False), self.assertEqual(('command-topic', 'beer on', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('switch.test') state = self.hass.states.get('switch.test')
self.assertEqual(STATE_ON, state.state) self.assertEqual(STATE_ON, state.state)
@ -80,7 +80,7 @@ class TestSensorMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer off', 2, False), self.assertEqual(('command-topic', 'beer off', 2, False),
self.mock_publish.mock_calls[-1][1]) self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('switch.test') state = self.hass.states.get('switch.test')
self.assertEqual(STATE_OFF, state.state) self.assertEqual(STATE_OFF, state.state)