MQTT convert to async (#6064)
* Migrate mqtt to async * address paulus comment / convert it complet async * adress paulus comment / remove future * Automation triggers should be async * Fix MQTT async calls * Show that event helpers are callbacks * Fix tests * Lint
This commit is contained in:
parent
fa2c1dafdf
commit
e1cbd6b4c0
25 changed files with 356 additions and 231 deletions
|
@ -412,7 +412,7 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||||
if platform is None:
|
if platform is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
remove = platform.async_trigger(hass, conf, action)
|
remove = yield from platform.async_trigger(hass, conf, action)
|
||||||
|
|
||||||
if not remove:
|
if not remove:
|
||||||
_LOGGER.error("Error setting up trigger %s", name)
|
_LOGGER.error("Error setting up trigger %s", name)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer event listening automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#event-trigger
|
at https://home-assistant.io/components/automation/#event-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for events based on configuration."""
|
"""Listen for events based on configuration."""
|
||||||
event_type = config.get(CONF_EVENT_TYPE)
|
event_type = config.get(CONF_EVENT_TYPE)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Trigger an automation when a LiteJet switch is released.
|
||||||
For more details about this platform, please refer to the documentation at
|
For more details about this platform, please refer to the documentation at
|
||||||
https://home-assistant.io/components/automation.litejet/
|
https://home-assistant.io/components/automation.litejet/
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -32,6 +33,7 @@ TRIGGER_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for events based on configuration."""
|
"""Listen for events based on configuration."""
|
||||||
number = config.get(CONF_NUMBER)
|
number = config.get(CONF_NUMBER)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#mqtt-trigger
|
at https://home-assistant.io/components/automation/#mqtt-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
topic = config.get(CONF_TOPIC)
|
topic = config.get(CONF_TOPIC)
|
||||||
|
@ -49,4 +51,6 @@ def async_trigger(hass, config, action):
|
||||||
'trigger': data
|
'trigger': data
|
||||||
})
|
})
|
||||||
|
|
||||||
return mqtt.async_subscribe(hass, topic, mqtt_automation_listener)
|
remove = yield from mqtt.async_subscribe(
|
||||||
|
hass, topic, mqtt_automation_listener)
|
||||||
|
return remove
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#numeric-state-trigger
|
at https://home-assistant.io/components/automation/#numeric-state-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
entity_id = config.get(CONF_ENTITY_ID)
|
entity_id = config.get(CONF_ENTITY_ID)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer state listening automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#state-trigger
|
at https://home-assistant.io/components/automation/#state-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
@ -34,6 +35,7 @@ TRIGGER_SCHEMA = vol.All(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
entity_id = config.get(CONF_ENTITY_ID)
|
entity_id = config.get(CONF_ENTITY_ID)
|
||||||
|
@ -97,6 +99,7 @@ def async_trigger(hass, config, action):
|
||||||
unsub = async_track_state_change(
|
unsub = async_track_state_change(
|
||||||
hass, entity_id, state_automation_listener, from_state, to_state)
|
hass, entity_id, state_automation_listener, from_state, to_state)
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_remove():
|
def async_remove():
|
||||||
"""Remove state listeners async."""
|
"""Remove state listeners async."""
|
||||||
unsub()
|
unsub()
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer sun based automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#sun-trigger
|
at https://home-assistant.io/components/automation/#sun-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for events based on configuration."""
|
"""Listen for events based on configuration."""
|
||||||
event = config.get(CONF_EVENT)
|
event = config.get(CONF_EVENT)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer template automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#template-trigger
|
at https://home-assistant.io/components/automation/#template-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -22,6 +23,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
value_template = config.get(CONF_VALUE_TEMPLATE)
|
value_template = config.get(CONF_VALUE_TEMPLATE)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer time listening automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#time-trigger
|
at https://home-assistant.io/components/automation/#time-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -29,6 +30,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
|
||||||
CONF_SECONDS, CONF_AFTER))
|
CONF_SECONDS, CONF_AFTER))
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
if CONF_AFTER in config:
|
if CONF_AFTER in config:
|
||||||
|
|
|
@ -4,6 +4,7 @@ Offer zone automation rules.
|
||||||
For more details about this automation rule, please refer to the documentation
|
For more details about this automation rule, please refer to the documentation
|
||||||
at https://home-assistant.io/components/automation/#zone-trigger
|
at https://home-assistant.io/components/automation/#zone-trigger
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_trigger(hass, config, action):
|
def async_trigger(hass, config, action):
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
entity_id = config.get(CONF_ENTITY_ID)
|
entity_id = config.get(CONF_ENTITY_ID)
|
||||||
|
|
|
@ -4,6 +4,7 @@ Support for MQTT message handling.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/components/mqtt/
|
https://home-assistant.io/components/mqtt/
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
@ -12,11 +13,12 @@ import time
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.bootstrap import prepare_setup_platform
|
from homeassistant.bootstrap import async_prepare_setup_platform
|
||||||
from homeassistant.config import load_yaml_config_file
|
from homeassistant.config import load_yaml_config_file
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import template, config_validation as cv
|
from homeassistant.helpers import template, config_validation as cv
|
||||||
from homeassistant.helpers.event import threaded_listener_factory
|
from homeassistant.util.async import (
|
||||||
|
run_coroutine_threadsafe, run_callback_threadsafe)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE,
|
||||||
CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD)
|
CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD)
|
||||||
|
@ -26,7 +28,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DOMAIN = 'mqtt'
|
DOMAIN = 'mqtt'
|
||||||
|
|
||||||
MQTT_CLIENT = None
|
DATA_MQTT = 'mqtt'
|
||||||
|
|
||||||
SERVICE_PUBLISH = 'publish'
|
SERVICE_PUBLISH = 'publish'
|
||||||
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
|
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
|
||||||
|
@ -183,11 +185,11 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
||||||
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
|
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@asyncio.coroutine
|
||||||
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
|
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
|
||||||
"""Subscribe to an MQTT topic."""
|
"""Subscribe to an MQTT topic."""
|
||||||
@callback
|
@callback
|
||||||
def mqtt_topic_subscriber(event):
|
def async_mqtt_topic_subscriber(event):
|
||||||
"""Match subscribed MQTT topic."""
|
"""Match subscribed MQTT topic."""
|
||||||
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||||
return
|
return
|
||||||
|
@ -195,61 +197,82 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
|
||||||
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
|
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
|
||||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
|
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
|
||||||
|
|
||||||
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
|
async_remove = hass.bus.async_listen(
|
||||||
mqtt_topic_subscriber)
|
EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
|
||||||
|
|
||||||
# Future: track subscriber count and unsubscribe in remove
|
|
||||||
MQTT_CLIENT.subscribe(topic, qos)
|
|
||||||
|
|
||||||
|
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
|
||||||
subscribe = threaded_listener_factory(async_subscribe)
|
"""Subscribe to an MQTT topic."""
|
||||||
|
async_remove = run_coroutine_threadsafe(
|
||||||
|
async_subscribe(hass, topic, msg_callback, qos),
|
||||||
|
hass.loop
|
||||||
|
).result()
|
||||||
|
|
||||||
|
def remove():
|
||||||
|
"""Remove listener convert."""
|
||||||
|
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||||
|
|
||||||
|
return remove
|
||||||
|
|
||||||
|
|
||||||
def _setup_server(hass, config):
|
@asyncio.coroutine
|
||||||
"""Try to start embedded MQTT broker."""
|
def _async_setup_server(hass, config):
|
||||||
|
"""Try to start embedded MQTT broker.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
conf = config.get(DOMAIN, {})
|
conf = config.get(DOMAIN, {})
|
||||||
|
|
||||||
# Only setup if embedded config passed in or no broker specified
|
# Only setup if embedded config passed in or no broker specified
|
||||||
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
|
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
server = prepare_setup_platform(hass, config, DOMAIN, 'server')
|
server = yield from async_prepare_setup_platform(
|
||||||
|
hass, config, DOMAIN, 'server')
|
||||||
|
|
||||||
if server is None:
|
if server is None:
|
||||||
_LOGGER.error("Unable to load embedded server")
|
_LOGGER.error("Unable to load embedded server")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
success, broker_config = server.start(hass, conf.get(CONF_EMBEDDED))
|
success, broker_config = \
|
||||||
|
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
|
||||||
|
|
||||||
return success and broker_config
|
return success and broker_config
|
||||||
|
|
||||||
|
|
||||||
def _setup_discovery(hass, config):
|
@asyncio.coroutine
|
||||||
"""Try to start the discovery of MQTT devices."""
|
def _async_setup_discovery(hass, config):
|
||||||
|
"""Try to start the discovery of MQTT devices.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
conf = config.get(DOMAIN, {})
|
conf = config.get(DOMAIN, {})
|
||||||
|
|
||||||
discovery = prepare_setup_platform(hass, config, DOMAIN, 'discovery')
|
discovery = yield from async_prepare_setup_platform(
|
||||||
|
hass, config, DOMAIN, 'discovery')
|
||||||
|
|
||||||
if discovery is None:
|
if discovery is None:
|
||||||
_LOGGER.error("Unable to load MQTT discovery")
|
_LOGGER.error("Unable to load MQTT discovery")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
success = discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config)
|
success = yield from discovery.async_start(
|
||||||
|
hass, conf[CONF_DISCOVERY_PREFIX], config)
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
def setup(hass, config):
|
@asyncio.coroutine
|
||||||
|
def async_setup(hass, config):
|
||||||
"""Start the MQTT protocol service."""
|
"""Start the MQTT protocol service."""
|
||||||
conf = config.get(DOMAIN, {})
|
conf = config.get(DOMAIN, {})
|
||||||
|
|
||||||
client_id = conf.get(CONF_CLIENT_ID)
|
client_id = conf.get(CONF_CLIENT_ID)
|
||||||
keepalive = conf.get(CONF_KEEPALIVE)
|
keepalive = conf.get(CONF_KEEPALIVE)
|
||||||
|
|
||||||
broker_config = _setup_server(hass, config)
|
broker_config = yield from _async_setup_server(hass, config)
|
||||||
|
|
||||||
if CONF_BROKER in conf:
|
if CONF_BROKER in conf:
|
||||||
broker = conf[CONF_BROKER]
|
broker = conf[CONF_BROKER]
|
||||||
|
@ -283,27 +306,31 @@ def setup(hass, config):
|
||||||
will_message = conf.get(CONF_WILL_MESSAGE)
|
will_message = conf.get(CONF_WILL_MESSAGE)
|
||||||
birth_message = conf.get(CONF_BIRTH_MESSAGE)
|
birth_message = conf.get(CONF_BIRTH_MESSAGE)
|
||||||
|
|
||||||
global MQTT_CLIENT
|
|
||||||
try:
|
try:
|
||||||
MQTT_CLIENT = MQTT(hass, broker, port, client_id, keepalive,
|
hass.data[DATA_MQTT] = MQTT(
|
||||||
username, password, certificate, client_key,
|
hass, broker, port, client_id, keepalive, username, password,
|
||||||
client_cert, tls_insecure, protocol, will_message,
|
certificate, client_key, client_cert, tls_insecure, protocol,
|
||||||
birth_message)
|
will_message, birth_message)
|
||||||
except socket.error:
|
except socket.error:
|
||||||
_LOGGER.exception("Can't connect to the broker. "
|
_LOGGER.exception("Can't connect to the broker. "
|
||||||
"Please check your settings and the broker itself")
|
"Please check your settings and the broker itself")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def stop_mqtt(event):
|
@asyncio.coroutine
|
||||||
|
def async_stop_mqtt(event):
|
||||||
"""Stop MQTT component."""
|
"""Stop MQTT component."""
|
||||||
MQTT_CLIENT.stop()
|
yield from hass.data[DATA_MQTT].async_stop()
|
||||||
|
|
||||||
def start_mqtt(event):
|
@asyncio.coroutine
|
||||||
|
def async_start_mqtt(event):
|
||||||
"""Launch MQTT component when Home Assistant starts up."""
|
"""Launch MQTT component when Home Assistant starts up."""
|
||||||
MQTT_CLIENT.start()
|
yield from hass.data[DATA_MQTT].async_start()
|
||||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_mqtt)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
||||||
|
|
||||||
def publish_service(call):
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, async_start_mqtt)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def async_publish_service(call):
|
||||||
"""Handle MQTT publish service calls."""
|
"""Handle MQTT publish service calls."""
|
||||||
msg_topic = call.data[ATTR_TOPIC]
|
msg_topic = call.data[ATTR_TOPIC]
|
||||||
payload = call.data.get(ATTR_PAYLOAD)
|
payload = call.data.get(ATTR_PAYLOAD)
|
||||||
|
@ -312,26 +339,28 @@ def setup(hass, config):
|
||||||
retain = call.data[ATTR_RETAIN]
|
retain = call.data[ATTR_RETAIN]
|
||||||
try:
|
try:
|
||||||
if payload_template is not None:
|
if payload_template is not None:
|
||||||
payload = template.Template(payload_template, hass).render()
|
payload = \
|
||||||
|
template.Template(payload_template, hass).async_render()
|
||||||
except template.jinja2.TemplateError as exc:
|
except template.jinja2.TemplateError as exc:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"Unable to publish to '%s': rendering payload template of "
|
"Unable to publish to '%s': rendering payload template of "
|
||||||
"'%s' failed because %s",
|
"'%s' failed because %s",
|
||||||
msg_topic, payload_template, exc)
|
msg_topic, payload_template, exc)
|
||||||
return
|
return
|
||||||
MQTT_CLIENT.publish(msg_topic, payload, qos, retain)
|
|
||||||
|
|
||||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_mqtt)
|
yield from hass.data[DATA_MQTT].async_publish(
|
||||||
|
msg_topic, payload, qos, retain)
|
||||||
|
|
||||||
descriptions = load_yaml_config_file(
|
descriptions = yield from hass.loop.run_in_executor(
|
||||||
os.path.join(os.path.dirname(__file__), 'services.yaml'))
|
None, load_yaml_config_file, os.path.join(
|
||||||
|
os.path.dirname(__file__), 'services.yaml'))
|
||||||
|
|
||||||
hass.services.register(DOMAIN, SERVICE_PUBLISH, publish_service,
|
hass.services.async_register(
|
||||||
descriptions.get(SERVICE_PUBLISH),
|
DOMAIN, SERVICE_PUBLISH, async_publish_service,
|
||||||
schema=MQTT_PUBLISH_SCHEMA)
|
descriptions.get(SERVICE_PUBLISH), schema=MQTT_PUBLISH_SCHEMA)
|
||||||
|
|
||||||
if conf.get(CONF_DISCOVERY):
|
if conf.get(CONF_DISCOVERY):
|
||||||
_setup_discovery(hass, config)
|
yield from _async_setup_discovery(hass, config)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -349,6 +378,7 @@ class MQTT(object):
|
||||||
self.topics = {}
|
self.topics = {}
|
||||||
self.progress = {}
|
self.progress = {}
|
||||||
self.birth_message = birth_message
|
self.birth_message = birth_message
|
||||||
|
self._mqttc = None
|
||||||
|
|
||||||
if protocol == PROTOCOL_31:
|
if protocol == PROTOCOL_31:
|
||||||
proto = mqtt.MQTTv31
|
proto = mqtt.MQTTv31
|
||||||
|
@ -364,8 +394,8 @@ class MQTT(object):
|
||||||
self._mqttc.username_pw_set(username, password)
|
self._mqttc.username_pw_set(username, password)
|
||||||
|
|
||||||
if certificate is not None:
|
if certificate is not None:
|
||||||
self._mqttc.tls_set(certificate, certfile=client_cert,
|
self._mqttc.tls_set(
|
||||||
keyfile=client_key)
|
certificate, certfile=client_cert, keyfile=client_key)
|
||||||
|
|
||||||
if tls_insecure is not None:
|
if tls_insecure is not None:
|
||||||
self._mqttc.tls_insecure_set(tls_insecure)
|
self._mqttc.tls_insecure_set(tls_insecure)
|
||||||
|
@ -375,40 +405,69 @@ class MQTT(object):
|
||||||
self._mqttc.on_connect = self._mqtt_on_connect
|
self._mqttc.on_connect = self._mqtt_on_connect
|
||||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||||
self._mqttc.on_message = self._mqtt_on_message
|
self._mqttc.on_message = self._mqtt_on_message
|
||||||
|
|
||||||
if will_message:
|
if will_message:
|
||||||
self._mqttc.will_set(will_message.get(ATTR_TOPIC),
|
self._mqttc.will_set(will_message.get(ATTR_TOPIC),
|
||||||
will_message.get(ATTR_PAYLOAD),
|
will_message.get(ATTR_PAYLOAD),
|
||||||
will_message.get(ATTR_QOS),
|
will_message.get(ATTR_QOS),
|
||||||
will_message.get(ATTR_RETAIN))
|
will_message.get(ATTR_RETAIN))
|
||||||
self._mqttc.connect(broker, port, keepalive)
|
|
||||||
|
|
||||||
def publish(self, topic, payload, qos, retain):
|
self._mqttc.connect_async(broker, port, keepalive)
|
||||||
"""Publish a MQTT message."""
|
|
||||||
self._mqttc.publish(topic, payload, qos, retain)
|
|
||||||
|
|
||||||
def start(self):
|
def async_publish(self, topic, payload, qos, retain):
|
||||||
"""Run the MQTT client."""
|
"""Publish a MQTT message.
|
||||||
self._mqttc.loop_start()
|
|
||||||
|
|
||||||
def stop(self):
|
This method must be run in the event loop and returns a coroutine.
|
||||||
"""Stop the MQTT client."""
|
"""
|
||||||
self._mqttc.disconnect()
|
return self.hass.loop.run_in_executor(
|
||||||
self._mqttc.loop_stop()
|
None, self._mqttc.publish, topic, payload, qos, retain)
|
||||||
|
|
||||||
def subscribe(self, topic, qos):
|
def async_start(self):
|
||||||
"""Subscribe to a topic."""
|
"""Run the MQTT client.
|
||||||
assert isinstance(topic, str)
|
|
||||||
|
This method must be run in the event loop and returns a coroutine.
|
||||||
|
"""
|
||||||
|
return self.hass.loop.run_in_executor(None, self._mqttc.loop_start)
|
||||||
|
|
||||||
|
def async_stop(self):
|
||||||
|
"""Stop the MQTT client.
|
||||||
|
|
||||||
|
This method must be run in the event loop and returns a coroutine.
|
||||||
|
"""
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the MQTT client."""
|
||||||
|
self._mqttc.disconnect()
|
||||||
|
self._mqttc.loop_stop()
|
||||||
|
|
||||||
|
return self.hass.loop.run_in_executor(None, stop)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def async_subscribe(self, topic, qos):
|
||||||
|
"""Subscribe to a topic.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
|
if not isinstance(topic, str):
|
||||||
|
raise HomeAssistantError("topic need to be a string!")
|
||||||
|
|
||||||
if topic in self.topics:
|
if topic in self.topics:
|
||||||
return
|
return
|
||||||
result, mid = self._mqttc.subscribe(topic, qos)
|
result, mid = yield from self.hass.loop.run_in_executor(
|
||||||
|
None, self._mqttc.subscribe, topic, qos)
|
||||||
|
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
self.progress[mid] = topic
|
self.progress[mid] = topic
|
||||||
self.topics[topic] = None
|
self.topics[topic] = None
|
||||||
|
|
||||||
def unsubscribe(self, topic):
|
@asyncio.coroutine
|
||||||
"""Unsubscribe from topic."""
|
def async_unsubscribe(self, topic):
|
||||||
result, mid = self._mqttc.unsubscribe(topic)
|
"""Unsubscribe from topic.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
|
result, mid = yield from self.hass.loop.run_in_executor(
|
||||||
|
None, self._mqttc.unsubscribe, topic)
|
||||||
|
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
self.progress[mid] = topic
|
self.progress[mid] = topic
|
||||||
|
|
||||||
|
@ -437,12 +496,14 @@ class MQTT(object):
|
||||||
for topic, qos in old_topics.items():
|
for topic, qos in old_topics.items():
|
||||||
# qos is None if we were in process of subscribing
|
# qos is None if we were in process of subscribing
|
||||||
if qos is not None:
|
if qos is not None:
|
||||||
self.subscribe(topic, qos)
|
self.hass.add_job(self.async_subscribe, topic, qos)
|
||||||
|
|
||||||
if self.birth_message:
|
if self.birth_message:
|
||||||
self.publish(self.birth_message.get(ATTR_TOPIC),
|
self.hass.add_job(self.async_publish(
|
||||||
self.birth_message.get(ATTR_PAYLOAD),
|
self.birth_message.get(ATTR_TOPIC),
|
||||||
self.birth_message.get(ATTR_QOS),
|
self.birth_message.get(ATTR_PAYLOAD),
|
||||||
self.birth_message.get(ATTR_RETAIN))
|
self.birth_message.get(ATTR_QOS),
|
||||||
|
self.birth_message.get(ATTR_RETAIN)))
|
||||||
|
|
||||||
def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
|
def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
|
||||||
"""Subscribe successful callback."""
|
"""Subscribe successful callback."""
|
||||||
|
|
|
@ -9,7 +9,6 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from homeassistant.core import callback
|
|
||||||
import homeassistant.components.mqtt as mqtt
|
import homeassistant.components.mqtt as mqtt
|
||||||
from homeassistant.components.mqtt import DOMAIN
|
from homeassistant.components.mqtt import DOMAIN
|
||||||
from homeassistant.helpers.discovery import async_load_platform
|
from homeassistant.helpers.discovery import async_load_platform
|
||||||
|
@ -24,7 +23,7 @@ TOPIC_MATCHER = re.compile(
|
||||||
SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor']
|
SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor']
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@asyncio.coroutine
|
||||||
def async_start(hass, discovery_topic, hass_config):
|
def async_start(hass, discovery_topic, hass_config):
|
||||||
"""Initialization of MQTT Discovery."""
|
"""Initialization of MQTT Discovery."""
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -56,7 +55,7 @@ def async_start(hass, discovery_topic, hass_config):
|
||||||
yield from async_load_platform(
|
yield from async_load_platform(
|
||||||
hass, component, DOMAIN, payload, hass_config)
|
hass, component, DOMAIN, payload, hass_config)
|
||||||
|
|
||||||
mqtt.async_subscribe(hass, discovery_topic + '/#',
|
yield from mqtt.async_subscribe(
|
||||||
async_device_message_received, 0)
|
hass, discovery_topic + '/#', async_device_message_received, 0)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -4,15 +4,14 @@ Support for a local MQTT broker.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/components/mqtt/#use-the-embedded-broker
|
https://home-assistant.io/components/mqtt/#use-the-embedded-broker
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util.async import run_coroutine_threadsafe
|
|
||||||
|
|
||||||
REQUIREMENTS = ['hbmqtt==0.8']
|
REQUIREMENTS = ['hbmqtt==0.8']
|
||||||
DEPENDENCIES = ['http']
|
DEPENDENCIES = ['http']
|
||||||
|
@ -29,8 +28,12 @@ HBMQTT_CONFIG_SCHEMA = vol.Any(None, vol.Schema({
|
||||||
}, extra=vol.ALLOW_EXTRA))
|
}, extra=vol.ALLOW_EXTRA))
|
||||||
|
|
||||||
|
|
||||||
def start(hass, server_config):
|
@asyncio.coroutine
|
||||||
"""Initialize MQTT Server."""
|
def async_start(hass, server_config):
|
||||||
|
"""Initialize MQTT Server.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
from hbmqtt.broker import Broker, BrokerException
|
from hbmqtt.broker import Broker, BrokerException
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -42,19 +45,20 @@ def start(hass, server_config):
|
||||||
client_config = None
|
client_config = None
|
||||||
|
|
||||||
broker = Broker(server_config, hass.loop)
|
broker = Broker(server_config, hass.loop)
|
||||||
run_coroutine_threadsafe(broker.start(), hass.loop).result()
|
yield from broker.start()
|
||||||
except BrokerException:
|
except BrokerException:
|
||||||
logging.getLogger(__name__).exception('Error initializing MQTT server')
|
logging.getLogger(__name__).exception('Error initializing MQTT server')
|
||||||
return False, None
|
return False, None
|
||||||
finally:
|
finally:
|
||||||
passwd.close()
|
passwd.close()
|
||||||
|
|
||||||
@callback
|
@asyncio.coroutine
|
||||||
def shutdown_mqtt_server(event):
|
def async_shutdown_mqtt_server(event):
|
||||||
"""Shut down the MQTT server."""
|
"""Shut down the MQTT server."""
|
||||||
hass.async_add_job(broker.shutdown())
|
yield from broker.shutdown()
|
||||||
|
|
||||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, shutdown_mqtt_server)
|
hass.bus.async_listen_once(
|
||||||
|
EVENT_HOMEASSISTANT_STOP, async_shutdown_mqtt_server)
|
||||||
|
|
||||||
return True, client_config
|
return True, client_config
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ def threaded_listener_factory(async_factory):
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_state_change(hass, entity_ids, action, from_state=None,
|
def async_track_state_change(hass, entity_ids, action, from_state=None,
|
||||||
to_state=None):
|
to_state=None):
|
||||||
"""Track specific state changes.
|
"""Track specific state changes.
|
||||||
|
@ -84,6 +85,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
|
||||||
track_state_change = threaded_listener_factory(async_track_state_change)
|
track_state_change = threaded_listener_factory(async_track_state_change)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_template(hass, template, action, variables=None):
|
def async_track_template(hass, template, action, variables=None):
|
||||||
"""Add a listener that track state changes with template condition."""
|
"""Add a listener that track state changes with template condition."""
|
||||||
from . import condition
|
from . import condition
|
||||||
|
@ -111,6 +113,7 @@ def async_track_template(hass, template, action, variables=None):
|
||||||
track_template = threaded_listener_factory(async_track_template)
|
track_template = threaded_listener_factory(async_track_template)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_point_in_time(hass, action, point_in_time):
|
def async_track_point_in_time(hass, action, point_in_time):
|
||||||
"""Add a listener that fires once after a specific point in time."""
|
"""Add a listener that fires once after a specific point in time."""
|
||||||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||||
|
@ -127,6 +130,7 @@ def async_track_point_in_time(hass, action, point_in_time):
|
||||||
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
|
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_point_in_utc_time(hass, action, point_in_time):
|
def async_track_point_in_utc_time(hass, action, point_in_time):
|
||||||
"""Add a listener that fires once after a specific point in UTC time."""
|
"""Add a listener that fires once after a specific point in UTC time."""
|
||||||
# Ensure point_in_time is UTC
|
# Ensure point_in_time is UTC
|
||||||
|
@ -160,6 +164,7 @@ track_point_in_utc_time = threaded_listener_factory(
|
||||||
async_track_point_in_utc_time)
|
async_track_point_in_utc_time)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_time_interval(hass, action, interval):
|
def async_track_time_interval(hass, action, interval):
|
||||||
"""Add a listener that fires repetitively at every timedelta interval."""
|
"""Add a listener that fires repetitively at every timedelta interval."""
|
||||||
remove = None
|
remove = None
|
||||||
|
@ -189,6 +194,7 @@ def async_track_time_interval(hass, action, interval):
|
||||||
track_time_interval = threaded_listener_factory(async_track_time_interval)
|
track_time_interval = threaded_listener_factory(async_track_time_interval)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_sunrise(hass, action, offset=None):
|
def async_track_sunrise(hass, action, offset=None):
|
||||||
"""Add a listener that will fire a specified offset from sunrise daily."""
|
"""Add a listener that will fire a specified offset from sunrise daily."""
|
||||||
from homeassistant.components import sun
|
from homeassistant.components import sun
|
||||||
|
@ -225,6 +231,7 @@ def async_track_sunrise(hass, action, offset=None):
|
||||||
track_sunrise = threaded_listener_factory(async_track_sunrise)
|
track_sunrise = threaded_listener_factory(async_track_sunrise)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_sunset(hass, action, offset=None):
|
def async_track_sunset(hass, action, offset=None):
|
||||||
"""Add a listener that will fire a specified offset from sunset daily."""
|
"""Add a listener that will fire a specified offset from sunset daily."""
|
||||||
from homeassistant.components import sun
|
from homeassistant.components import sun
|
||||||
|
@ -261,6 +268,7 @@ def async_track_sunset(hass, action, offset=None):
|
||||||
track_sunset = threaded_listener_factory(async_track_sunset)
|
track_sunset = threaded_listener_factory(async_track_sunset)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
||||||
hour=None, minute=None, second=None,
|
hour=None, minute=None, second=None,
|
||||||
local=False):
|
local=False):
|
||||||
|
@ -305,6 +313,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
||||||
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
|
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_track_time_change(hass, action, year=None, month=None, day=None,
|
def async_track_time_change(hass, action, year=None, month=None, day=None,
|
||||||
hour=None, minute=None, second=None):
|
hour=None, minute=None, second=None):
|
||||||
"""Add a listener that will fire if UTC time matches a pattern."""
|
"""Add a listener that will fire if UTC time matches a pattern."""
|
||||||
|
|
|
@ -86,7 +86,13 @@ def async_test_home_assistant(loop):
|
||||||
loop._thread_ident = threading.get_ident()
|
loop._thread_ident = threading.get_ident()
|
||||||
|
|
||||||
hass = ha.HomeAssistant(loop)
|
hass = ha.HomeAssistant(loop)
|
||||||
hass.async_track_tasks()
|
|
||||||
|
def async_add_job(target, *args):
|
||||||
|
if isinstance(target, MagicMock):
|
||||||
|
return
|
||||||
|
hass._async_add_job_tracking(target, *args)
|
||||||
|
|
||||||
|
hass.async_add_job = async_add_job
|
||||||
|
|
||||||
hass.config.location_name = 'test home'
|
hass.config.location_name = 'test home'
|
||||||
hass.config.config_dir = get_test_config_dir()
|
hass.config.config_dir = get_test_config_dir()
|
||||||
|
|
|
@ -111,7 +111,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
|
||||||
alarm_control_panel.alarm_arm_home(self.hass)
|
alarm_control_panel.alarm_arm_home(self.hass)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(('alarm/command', 'ARM_HOME', 0, False),
|
self.assertEqual(('alarm/command', 'ARM_HOME', 0, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
|
|
||||||
def test_arm_home_not_publishes_mqtt_with_invalid_code(self):
|
def test_arm_home_not_publishes_mqtt_with_invalid_code(self):
|
||||||
"""Test not publishing of MQTT messages with invalid code."""
|
"""Test not publishing of MQTT messages with invalid code."""
|
||||||
|
@ -146,7 +146,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
|
||||||
alarm_control_panel.alarm_arm_away(self.hass)
|
alarm_control_panel.alarm_arm_away(self.hass)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False),
|
self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
|
|
||||||
def test_arm_away_not_publishes_mqtt_with_invalid_code(self):
|
def test_arm_away_not_publishes_mqtt_with_invalid_code(self):
|
||||||
"""Test not publishing of MQTT messages with invalid code."""
|
"""Test not publishing of MQTT messages with invalid code."""
|
||||||
|
@ -181,7 +181,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
|
||||||
alarm_control_panel.alarm_disarm(self.hass)
|
alarm_control_panel.alarm_disarm(self.hass)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(('alarm/command', 'DISARM', 0, False),
|
self.assertEqual(('alarm/command', 'DISARM', 0, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
|
|
||||||
def test_disarm_not_publishes_mqtt_with_invalid_code(self):
|
def test_disarm_not_publishes_mqtt_with_invalid_code(self):
|
||||||
"""Test not publishing of MQTT messages with invalid code."""
|
"""Test not publishing of MQTT messages with invalid code."""
|
||||||
|
|
|
@ -118,7 +118,7 @@ class TestCoverMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'OPEN', 0, False),
|
self.assertEqual(('command-topic', 'OPEN', 0, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('cover.test')
|
state = self.hass.states.get('cover.test')
|
||||||
self.assertEqual(STATE_OPEN, state.state)
|
self.assertEqual(STATE_OPEN, state.state)
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class TestCoverMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'CLOSE', 0, False),
|
self.assertEqual(('command-topic', 'CLOSE', 0, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('cover.test')
|
state = self.hass.states.get('cover.test')
|
||||||
self.assertEqual(STATE_CLOSED, state.state)
|
self.assertEqual(STATE_CLOSED, state.state)
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ class TestCoverMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'OPEN', 2, False),
|
self.assertEqual(('command-topic', 'OPEN', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('cover.test')
|
state = self.hass.states.get('cover.test')
|
||||||
self.assertEqual(STATE_UNKNOWN, state.state)
|
self.assertEqual(STATE_UNKNOWN, state.state)
|
||||||
|
|
||||||
|
@ -174,7 +174,7 @@ class TestCoverMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'CLOSE', 2, False),
|
self.assertEqual(('command-topic', 'CLOSE', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('cover.test')
|
state = self.hass.states.get('cover.test')
|
||||||
self.assertEqual(STATE_UNKNOWN, state.state)
|
self.assertEqual(STATE_UNKNOWN, state.state)
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ class TestCoverMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'STOP', 2, False),
|
self.assertEqual(('command-topic', 'STOP', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('cover.test')
|
state = self.hass.states.get('cover.test')
|
||||||
self.assertEqual(STATE_UNKNOWN, state.state)
|
self.assertEqual(STATE_UNKNOWN, state.state)
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,7 @@ light:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from homeassistant.bootstrap import setup_component
|
from homeassistant.bootstrap import setup_component
|
||||||
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE
|
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE
|
||||||
|
@ -328,7 +329,7 @@ class TestLightMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
|
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_ON, state.state)
|
self.assertEqual(STATE_ON, state.state)
|
||||||
|
|
||||||
|
@ -336,27 +337,20 @@ class TestLightMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
|
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_OFF, state.state)
|
self.assertEqual(STATE_OFF, state.state)
|
||||||
|
|
||||||
|
self.mock_publish.reset_mock()
|
||||||
light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75],
|
light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75],
|
||||||
brightness=50)
|
brightness=50)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
# Calls are threaded so we need to reorder them
|
self.mock_publish().async_publish.assert_has_calls([
|
||||||
bright_call, rgb_call, state_call = \
|
mock.call('test_light_rgb/set', 'on', 2, False),
|
||||||
sorted((call[1] for call in self.mock_publish.mock_calls[-3:]),
|
mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False),
|
||||||
key=lambda call: call[0])
|
mock.call('test_light_rgb/brightness/set', 50, 2, False),
|
||||||
|
], any_order=True)
|
||||||
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
|
|
||||||
state_call)
|
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/rgb/set', '75,75,75', 2, False),
|
|
||||||
rgb_call)
|
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/brightness/set', 50, 2, False),
|
|
||||||
bright_call)
|
|
||||||
|
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_ON, state.state)
|
self.assertEqual(STATE_ON, state.state)
|
||||||
|
|
|
@ -172,7 +172,7 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False),
|
self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_ON, state.state)
|
self.assertEqual(STATE_ON, state.state)
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False),
|
self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_OFF, state.state)
|
self.assertEqual(STATE_OFF, state.state)
|
||||||
|
|
||||||
|
@ -189,11 +189,11 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
# Get the sent message
|
# Get the sent message
|
||||||
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
|
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
|
||||||
self.assertEqual(50, message_json["brightness"])
|
self.assertEqual(50, message_json["brightness"])
|
||||||
self.assertEqual(75, message_json["color"]["r"])
|
self.assertEqual(75, message_json["color"]["r"])
|
||||||
self.assertEqual(75, message_json["color"]["g"])
|
self.assertEqual(75, message_json["color"]["g"])
|
||||||
|
@ -228,11 +228,11 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
# Get the sent message
|
# Get the sent message
|
||||||
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
|
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
|
||||||
self.assertEqual(5, message_json["flash"])
|
self.assertEqual(5, message_json["flash"])
|
||||||
self.assertEqual("ON", message_json["state"])
|
self.assertEqual("ON", message_json["state"])
|
||||||
|
|
||||||
|
@ -240,11 +240,11 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
# Get the sent message
|
# Get the sent message
|
||||||
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
|
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
|
||||||
self.assertEqual(15, message_json["flash"])
|
self.assertEqual(15, message_json["flash"])
|
||||||
self.assertEqual("ON", message_json["state"])
|
self.assertEqual("ON", message_json["state"])
|
||||||
|
|
||||||
|
@ -268,11 +268,11 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
# Get the sent message
|
# Get the sent message
|
||||||
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
|
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
|
||||||
self.assertEqual(10, message_json["transition"])
|
self.assertEqual(10, message_json["transition"])
|
||||||
self.assertEqual("ON", message_json["state"])
|
self.assertEqual("ON", message_json["state"])
|
||||||
|
|
||||||
|
@ -281,11 +281,11 @@ class TestLightMQTTJSON(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
# Get the sent message
|
# Get the sent message
|
||||||
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
|
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
|
||||||
self.assertEqual(10, message_json["transition"])
|
self.assertEqual(10, message_json["transition"])
|
||||||
self.assertEqual("OFF", message_json["state"])
|
self.assertEqual("OFF", message_json["state"])
|
||||||
|
|
||||||
|
|
|
@ -196,7 +196,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False),
|
self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_ON, state.state)
|
self.assertEqual(STATE_ON, state.state)
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
|
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('light.test')
|
state = self.hass.states.get('light.test')
|
||||||
self.assertEqual(STATE_OFF, state.state)
|
self.assertEqual(STATE_OFF, state.state)
|
||||||
|
|
||||||
|
@ -215,12 +215,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
|
|
||||||
# check the payload
|
# check the payload
|
||||||
payload = self.mock_publish.mock_calls[-1][1][1]
|
payload = self.mock_publish.mock_calls[-2][1][1]
|
||||||
self.assertEqual('on,50,75-75-75', payload)
|
self.assertEqual('on,50,75-75-75', payload)
|
||||||
|
|
||||||
# check the state
|
# check the state
|
||||||
|
@ -253,12 +253,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
|
|
||||||
# check the payload
|
# check the payload
|
||||||
payload = self.mock_publish.mock_calls[-1][1][1]
|
payload = self.mock_publish.mock_calls[-2][1][1]
|
||||||
self.assertEqual('on,short', payload)
|
self.assertEqual('on,short', payload)
|
||||||
|
|
||||||
# long flash
|
# long flash
|
||||||
|
@ -266,12 +266,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
|
|
||||||
# check the payload
|
# check the payload
|
||||||
payload = self.mock_publish.mock_calls[-1][1][1]
|
payload = self.mock_publish.mock_calls[-2][1][1]
|
||||||
self.assertEqual('on,long', payload)
|
self.assertEqual('on,long', payload)
|
||||||
|
|
||||||
def test_transition(self):
|
def test_transition(self):
|
||||||
|
@ -296,12 +296,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
|
|
||||||
# check the payload
|
# check the payload
|
||||||
payload = self.mock_publish.mock_calls[-1][1][1]
|
payload = self.mock_publish.mock_calls[-2][1][1]
|
||||||
self.assertEqual('on,10', payload)
|
self.assertEqual('on,10', payload)
|
||||||
|
|
||||||
# transition off
|
# transition off
|
||||||
|
@ -309,12 +309,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual('test_light_rgb/set',
|
self.assertEqual('test_light_rgb/set',
|
||||||
self.mock_publish.mock_calls[-1][1][0])
|
self.mock_publish.mock_calls[-2][1][0])
|
||||||
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
|
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
|
||||||
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
|
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
|
||||||
|
|
||||||
# check the payload
|
# check the payload
|
||||||
payload = self.mock_publish.mock_calls[-1][1][1]
|
payload = self.mock_publish.mock_calls[-2][1][1]
|
||||||
self.assertEqual('off,4', payload)
|
self.assertEqual('off,4', payload)
|
||||||
|
|
||||||
def test_invalid_values(self): \
|
def test_invalid_values(self): \
|
||||||
|
|
|
@ -73,7 +73,7 @@ class TestLockMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'LOCK', 2, False),
|
self.assertEqual(('command-topic', 'LOCK', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('lock.test')
|
state = self.hass.states.get('lock.test')
|
||||||
self.assertEqual(STATE_LOCKED, state.state)
|
self.assertEqual(STATE_LOCKED, state.state)
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ class TestLockMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'UNLOCK', 2, False),
|
self.assertEqual(('command-topic', 'UNLOCK', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('lock.test')
|
state = self.hass.states.get('lock.test')
|
||||||
self.assertEqual(STATE_UNLOCKED, state.state)
|
self.assertEqual(STATE_UNLOCKED, state.state)
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,10 @@ def test_subscribing_config_topic(hass, mqtt_mock):
|
||||||
"""Test setting up discovery."""
|
"""Test setting up discovery."""
|
||||||
hass_config = {}
|
hass_config = {}
|
||||||
discovery_topic = 'homeassistant'
|
discovery_topic = 'homeassistant'
|
||||||
async_start(hass, discovery_topic, hass_config)
|
yield from async_start(hass, discovery_topic, hass_config)
|
||||||
assert mqtt_mock.subscribe.called
|
|
||||||
call_args = mqtt_mock.subscribe.mock_calls[0][1]
|
assert mqtt_mock.async_subscribe.called
|
||||||
|
call_args = mqtt_mock.async_subscribe.mock_calls[0][1]
|
||||||
assert call_args[0] == discovery_topic + '/#'
|
assert call_args[0] == discovery_topic + '/#'
|
||||||
assert call_args[1] == 0
|
assert call_args[1] == 0
|
||||||
|
|
||||||
|
@ -24,7 +25,7 @@ def test_subscribing_config_topic(hass, mqtt_mock):
|
||||||
def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
|
def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
|
||||||
"""Test sending in invalid JSON."""
|
"""Test sending in invalid JSON."""
|
||||||
mock_load_platform.return_value = mock_coro()
|
mock_load_platform.return_value = mock_coro()
|
||||||
async_start(hass, 'homeassistant', {})
|
yield from async_start(hass, 'homeassistant', {})
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config',
|
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config',
|
||||||
'{}')
|
'{}')
|
||||||
|
@ -37,7 +38,7 @@ def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
|
||||||
def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
|
def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
|
||||||
"""Test sending in invalid JSON."""
|
"""Test sending in invalid JSON."""
|
||||||
mock_load_platform.return_value = mock_coro()
|
mock_load_platform.return_value = mock_coro()
|
||||||
async_start(hass, 'homeassistant', {})
|
yield from async_start(hass, 'homeassistant', {})
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
|
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
|
||||||
'not json')
|
'not json')
|
||||||
|
@ -51,7 +52,7 @@ def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
|
||||||
def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
|
def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
|
||||||
"""Test sending in invalid JSON."""
|
"""Test sending in invalid JSON."""
|
||||||
mock_load_platform.return_value = mock_coro()
|
mock_load_platform.return_value = mock_coro()
|
||||||
async_start(hass, 'homeassistant', {})
|
yield from async_start(hass, 'homeassistant', {})
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}')
|
async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}')
|
||||||
yield from hass.async_block_till_done()
|
yield from hass.async_block_till_done()
|
||||||
|
@ -62,7 +63,7 @@ def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_correct_config_discovery(hass, mqtt_mock, caplog):
|
def test_correct_config_discovery(hass, mqtt_mock, caplog):
|
||||||
"""Test sending in invalid JSON."""
|
"""Test sending in invalid JSON."""
|
||||||
async_start(hass, 'homeassistant', {})
|
yield from async_start(hass, 'homeassistant', {})
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
|
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
|
||||||
'{ "name": "Beer" }')
|
'{ "name": "Beer" }')
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""The tests for the MQTT component."""
|
"""The tests for the MQTT component."""
|
||||||
from collections import namedtuple
|
import asyncio
|
||||||
|
from collections import namedtuple, OrderedDict
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import socket
|
import socket
|
||||||
|
@ -7,14 +8,29 @@ import socket
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.bootstrap import setup_component
|
from homeassistant.bootstrap import setup_component, async_setup_component
|
||||||
import homeassistant.components.mqtt as mqtt
|
import homeassistant.components.mqtt as mqtt
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
|
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
|
||||||
EVENT_HOMEASSISTANT_STOP)
|
EVENT_HOMEASSISTANT_STOP)
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message)
|
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def mock_mqtt_client(hass, config=None):
|
||||||
|
"""Mock the MQTT paho client."""
|
||||||
|
if config is None:
|
||||||
|
config = {
|
||||||
|
mqtt.CONF_BROKER: 'mock-broker'
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch('paho.mqtt.client.Client') as mock_client:
|
||||||
|
yield from async_setup_component(hass, mqtt.DOMAIN, {
|
||||||
|
mqtt.DOMAIN: config
|
||||||
|
})
|
||||||
|
return mock_client()
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -40,7 +56,7 @@ class TestMQTT(unittest.TestCase):
|
||||||
""""Test if client start on HA launch."""
|
""""Test if client start on HA launch."""
|
||||||
self.hass.bus.fire(EVENT_HOMEASSISTANT_START)
|
self.hass.bus.fire(EVENT_HOMEASSISTANT_START)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT.start.called)
|
self.assertTrue(self.hass.data['mqtt'].async_start.called)
|
||||||
|
|
||||||
def test_client_stops_on_home_assistant_start(self):
|
def test_client_stops_on_home_assistant_start(self):
|
||||||
"""Test if client stops on HA launch."""
|
"""Test if client stops on HA launch."""
|
||||||
|
@ -48,7 +64,7 @@ class TestMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
|
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT.stop.called)
|
self.assertTrue(self.hass.data['mqtt'].async_stop.called)
|
||||||
|
|
||||||
@mock.patch('paho.mqtt.client.Client')
|
@mock.patch('paho.mqtt.client.Client')
|
||||||
def test_setup_fails_if_no_connect_broker(self, _):
|
def test_setup_fails_if_no_connect_broker(self, _):
|
||||||
|
@ -69,14 +85,17 @@ class TestMQTT(unittest.TestCase):
|
||||||
"""Test setting up embedded server with no config."""
|
"""Test setting up embedded server with no config."""
|
||||||
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
|
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
|
||||||
|
|
||||||
with mock.patch('homeassistant.components.mqtt.server.start',
|
with mock.patch('homeassistant.components.mqtt.server.async_start',
|
||||||
return_value=(True, client_config)) as _start:
|
return_value=mock_coro(
|
||||||
|
return_value=(True, client_config))
|
||||||
|
) as _start:
|
||||||
self.hass.config.components = set()
|
self.hass.config.components = set()
|
||||||
assert setup_component(self.hass, mqtt.DOMAIN,
|
assert setup_component(self.hass, mqtt.DOMAIN,
|
||||||
{mqtt.DOMAIN: {}})
|
{mqtt.DOMAIN: {}})
|
||||||
assert _start.call_count == 1
|
assert _start.call_count == 1
|
||||||
|
|
||||||
# Test with `embedded: None`
|
# Test with `embedded: None`
|
||||||
|
_start.return_value = mock_coro(return_value=(True, client_config))
|
||||||
self.hass.config.components = set()
|
self.hass.config.components = set()
|
||||||
assert setup_component(self.hass, mqtt.DOMAIN,
|
assert setup_component(self.hass, mqtt.DOMAIN,
|
||||||
{mqtt.DOMAIN: {'embedded': None}})
|
{mqtt.DOMAIN: {'embedded': None}})
|
||||||
|
@ -105,7 +124,7 @@ class TestMQTT(unittest.TestCase):
|
||||||
ATTR_SERVICE: mqtt.SERVICE_PUBLISH
|
ATTR_SERVICE: mqtt.SERVICE_PUBLISH
|
||||||
})
|
})
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertTrue(not mqtt.MQTT_CLIENT.publish.called)
|
self.assertTrue(not self.hass.data['mqtt'].async_publish.called)
|
||||||
|
|
||||||
def test_service_call_with_template_payload_renders_template(self):
|
def test_service_call_with_template_payload_renders_template(self):
|
||||||
"""Test the service call with rendered template.
|
"""Test the service call with rendered template.
|
||||||
|
@ -114,8 +133,9 @@ class TestMQTT(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}")
|
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}")
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT.publish.called)
|
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
|
||||||
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][1], "2")
|
self.assertEqual(
|
||||||
|
self.hass.data['mqtt'].async_publish.call_args[0][1], "2")
|
||||||
|
|
||||||
def test_service_call_with_payload_doesnt_render_template(self):
|
def test_service_call_with_payload_doesnt_render_template(self):
|
||||||
"""Test the service call with unrendered template.
|
"""Test the service call with unrendered template.
|
||||||
|
@ -129,7 +149,7 @@ class TestMQTT(unittest.TestCase):
|
||||||
mqtt.ATTR_PAYLOAD: payload,
|
mqtt.ATTR_PAYLOAD: payload,
|
||||||
mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template
|
mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template
|
||||||
}, blocking=True)
|
}, blocking=True)
|
||||||
self.assertFalse(mqtt.MQTT_CLIENT.publish.called)
|
self.assertFalse(self.hass.data['mqtt'].async_publish.called)
|
||||||
|
|
||||||
def test_service_call_with_ascii_qos_retain_flags(self):
|
def test_service_call_with_ascii_qos_retain_flags(self):
|
||||||
"""Test the service call with args that can be misinterpreted.
|
"""Test the service call with args that can be misinterpreted.
|
||||||
|
@ -142,9 +162,10 @@ class TestMQTT(unittest.TestCase):
|
||||||
mqtt.ATTR_QOS: '2',
|
mqtt.ATTR_QOS: '2',
|
||||||
mqtt.ATTR_RETAIN: 'no'
|
mqtt.ATTR_RETAIN: 'no'
|
||||||
}, blocking=True)
|
}, blocking=True)
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT.publish.called)
|
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
|
||||||
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][2], 2)
|
self.assertEqual(
|
||||||
self.assertFalse(mqtt.MQTT_CLIENT.publish.call_args[0][3])
|
self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
|
||||||
|
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
|
||||||
|
|
||||||
def test_subscribe_topic(self):
|
def test_subscribe_topic(self):
|
||||||
"""Test the subscription of a topic."""
|
"""Test the subscription of a topic."""
|
||||||
|
@ -231,15 +252,12 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
def setUp(self): # pylint: disable=invalid-name
|
def setUp(self): # pylint: disable=invalid-name
|
||||||
"""Setup things to be run when tests are started."""
|
"""Setup things to be run when tests are started."""
|
||||||
self.hass = get_test_home_assistant()
|
self.hass = get_test_home_assistant()
|
||||||
# mock_mqtt_component(self.hass)
|
|
||||||
|
|
||||||
with mock.patch('paho.mqtt.client.Client'):
|
with mock.patch('paho.mqtt.client.Client'):
|
||||||
self.hass.config.components = set()
|
self.hass.config.components = set()
|
||||||
assert setup_component(self.hass, mqtt.DOMAIN, {
|
assert setup_component(self.hass, mqtt.DOMAIN, {
|
||||||
mqtt.DOMAIN: {
|
mqtt.DOMAIN: {
|
||||||
mqtt.CONF_BROKER: 'mock-broker',
|
mqtt.CONF_BROKER: 'mock-broker',
|
||||||
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
|
|
||||||
mqtt.ATTR_PAYLOAD: 'birth'}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -261,7 +279,8 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||||
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
|
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
|
||||||
|
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_message(None, {'hass': self.hass}, message)
|
self.hass.data['mqtt']._mqtt_on_message(
|
||||||
|
None, {'hass': self.hass}, message)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(1, len(calls))
|
self.assertEqual(1, len(calls))
|
||||||
|
@ -273,68 +292,36 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
def test_mqtt_failed_connection_results_in_disconnect(self):
|
def test_mqtt_failed_connection_results_in_disconnect(self):
|
||||||
"""Test if connection failure leads to disconnect."""
|
"""Test if connection failure leads to disconnect."""
|
||||||
for result_code in range(1, 6):
|
for result_code in range(1, 6):
|
||||||
mqtt.MQTT_CLIENT._mqttc = mock.MagicMock()
|
self.hass.data['mqtt']._mqttc = mock.MagicMock()
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_connect(None, {'topics': {}}, 0,
|
self.hass.data['mqtt']._mqtt_on_connect(
|
||||||
result_code)
|
None, {'topics': {}}, 0, result_code)
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.disconnect.called)
|
self.assertTrue(self.hass.data['mqtt']._mqttc.disconnect.called)
|
||||||
|
|
||||||
def test_mqtt_subscribes_topics_on_connect(self):
|
|
||||||
"""Test subscription to topic on connect."""
|
|
||||||
from collections import OrderedDict
|
|
||||||
prev_topics = OrderedDict()
|
|
||||||
prev_topics['topic/test'] = 1,
|
|
||||||
prev_topics['home/sensor'] = 2,
|
|
||||||
prev_topics['still/pending'] = None
|
|
||||||
|
|
||||||
mqtt.MQTT_CLIENT.topics = prev_topics
|
|
||||||
mqtt.MQTT_CLIENT.progress = {1: 'still/pending'}
|
|
||||||
# Return values for subscribe calls (rc, mid)
|
|
||||||
mqtt.MQTT_CLIENT._mqttc.subscribe.side_effect = ((0, 2), (0, 3))
|
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
|
|
||||||
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.disconnect.called)
|
|
||||||
|
|
||||||
expected = [(topic, qos) for topic, qos in prev_topics.items()
|
|
||||||
if qos is not None]
|
|
||||||
self.assertEqual(
|
|
||||||
expected,
|
|
||||||
[call[1] for call in mqtt.MQTT_CLIENT._mqttc.subscribe.mock_calls])
|
|
||||||
self.assertEqual({
|
|
||||||
1: 'still/pending',
|
|
||||||
2: 'topic/test',
|
|
||||||
3: 'home/sensor',
|
|
||||||
}, mqtt.MQTT_CLIENT.progress)
|
|
||||||
|
|
||||||
def test_mqtt_birth_message_on_connect(self): \
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
"""Test birth message on connect."""
|
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
|
|
||||||
mqtt.MQTT_CLIENT._mqttc.publish.assert_called_with('birth', 'birth', 0,
|
|
||||||
False)
|
|
||||||
|
|
||||||
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
|
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
|
||||||
"""Test the disconnect tries."""
|
"""Test the disconnect tries."""
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 0)
|
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
|
||||||
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.reconnect.called)
|
self.assertFalse(self.hass.data['mqtt']._mqttc.reconnect.called)
|
||||||
|
|
||||||
@mock.patch('homeassistant.components.mqtt.time.sleep')
|
@mock.patch('homeassistant.components.mqtt.time.sleep')
|
||||||
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
|
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
|
||||||
"""Test the re-connect tries."""
|
"""Test the re-connect tries."""
|
||||||
mqtt.MQTT_CLIENT.topics = {
|
self.hass.data['mqtt'].topics = {
|
||||||
'test/topic': 1,
|
'test/topic': 1,
|
||||||
'test/progress': None
|
'test/progress': None
|
||||||
}
|
}
|
||||||
mqtt.MQTT_CLIENT.progress = {
|
self.hass.data['mqtt'].progress = {
|
||||||
1: 'test/progress'
|
1: 'test/progress'
|
||||||
}
|
}
|
||||||
mqtt.MQTT_CLIENT._mqttc.reconnect.side_effect = [1, 1, 1, 0]
|
self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0]
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 1)
|
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1)
|
||||||
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.reconnect.called)
|
self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called)
|
||||||
self.assertEqual(4, len(mqtt.MQTT_CLIENT._mqttc.reconnect.mock_calls))
|
self.assertEqual(
|
||||||
|
4, len(self.hass.data['mqtt']._mqttc.reconnect.mock_calls))
|
||||||
self.assertEqual([1, 2, 4],
|
self.assertEqual([1, 2, 4],
|
||||||
[call[1][0] for call in mock_sleep.mock_calls])
|
[call[1][0] for call in mock_sleep.mock_calls])
|
||||||
|
|
||||||
self.assertEqual({'test/topic': 1}, mqtt.MQTT_CLIENT.topics)
|
self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics)
|
||||||
self.assertEqual({}, mqtt.MQTT_CLIENT.progress)
|
self.assertEqual({}, self.hass.data['mqtt'].progress)
|
||||||
|
|
||||||
def test_invalid_mqtt_topics(self):
|
def test_invalid_mqtt_topics(self):
|
||||||
"""Test invalid topics."""
|
"""Test invalid topics."""
|
||||||
|
@ -356,7 +343,7 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
|
||||||
message = MQTTMessage(topic, 1, payload)
|
message = MQTTMessage(topic, 1, payload)
|
||||||
with self.assertLogs(level='ERROR') as test_handle:
|
with self.assertLogs(level='ERROR') as test_handle:
|
||||||
mqtt.MQTT_CLIENT._mqtt_on_message(
|
self.hass.data['mqtt']._mqtt_on_message(
|
||||||
None,
|
None,
|
||||||
{'hass': self.hass},
|
{'hass': self.hass},
|
||||||
message)
|
message)
|
||||||
|
@ -365,3 +352,47 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
"ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode "
|
"ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode "
|
||||||
"payload from MQTT topic: %s, Payload: " % topic,
|
"payload from MQTT topic: %s, Payload: " % topic,
|
||||||
test_handle.output[0])
|
test_handle.output[0])
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_birth_message(hass):
|
||||||
|
"""Test sending birth message."""
|
||||||
|
mqtt_client = yield from mock_mqtt_client(hass, {
|
||||||
|
mqtt.CONF_BROKER: 'mock-broker',
|
||||||
|
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
|
||||||
|
mqtt.ATTR_PAYLOAD: 'birth'}
|
||||||
|
})
|
||||||
|
calls = []
|
||||||
|
mqtt_client.publish = lambda *args: calls.append(args)
|
||||||
|
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
assert calls[-1] == ('birth', 'birth', 0, False)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_mqtt_subscribes_topics_on_connect(hass):
|
||||||
|
"""Test subscription to topic on connect."""
|
||||||
|
mqtt_client = yield from mock_mqtt_client(hass)
|
||||||
|
|
||||||
|
prev_topics = OrderedDict()
|
||||||
|
prev_topics['topic/test'] = 1,
|
||||||
|
prev_topics['home/sensor'] = 2,
|
||||||
|
prev_topics['still/pending'] = None
|
||||||
|
|
||||||
|
hass.data['mqtt'].topics = prev_topics
|
||||||
|
hass.data['mqtt'].progress = {1: 'still/pending'}
|
||||||
|
|
||||||
|
# Return values for subscribe calls (rc, mid)
|
||||||
|
mqtt_client.subscribe.side_effect = ((0, 2), (0, 3))
|
||||||
|
|
||||||
|
hass.add_job = mock.MagicMock()
|
||||||
|
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
|
||||||
|
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert not mqtt_client.disconnect.called
|
||||||
|
|
||||||
|
expected = [(topic, qos) for topic, qos in prev_topics.items()
|
||||||
|
if qos is not None]
|
||||||
|
|
||||||
|
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected
|
||||||
|
|
|
@ -4,7 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
|
||||||
from homeassistant.bootstrap import setup_component
|
from homeassistant.bootstrap import setup_component
|
||||||
import homeassistant.components.mqtt as mqtt
|
import homeassistant.components.mqtt as mqtt
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant
|
from tests.common import get_test_home_assistant, mock_coro
|
||||||
|
|
||||||
|
|
||||||
class TestMQTT:
|
class TestMQTT:
|
||||||
|
@ -21,9 +21,8 @@ class TestMQTT:
|
||||||
|
|
||||||
@patch('passlib.apps.custom_app_context', Mock(return_value=''))
|
@patch('passlib.apps.custom_app_context', Mock(return_value=''))
|
||||||
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
|
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
|
||||||
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe',
|
|
||||||
Mock(return_value=MagicMock()))
|
|
||||||
@patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock()))
|
@patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock()))
|
||||||
|
@patch('hbmqtt.broker.Broker.start', Mock(return_value=mock_coro()))
|
||||||
@patch('homeassistant.components.mqtt.MQTT')
|
@patch('homeassistant.components.mqtt.MQTT')
|
||||||
def test_creating_config_with_http_pass(self, mock_mqtt):
|
def test_creating_config_with_http_pass(self, mock_mqtt):
|
||||||
"""Test if the MQTT server gets started and subscribe/publish msg."""
|
"""Test if the MQTT server gets started and subscribe/publish msg."""
|
||||||
|
@ -46,7 +45,7 @@ class TestMQTT:
|
||||||
assert mock_mqtt.mock_calls[0][1][6] is None
|
assert mock_mqtt.mock_calls[0][1][6] is None
|
||||||
|
|
||||||
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
|
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
|
||||||
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe')
|
@patch('hbmqtt.broker.Broker.start', return_value=mock_coro())
|
||||||
def test_broker_config_fails(self, mock_run):
|
def test_broker_config_fails(self, mock_run):
|
||||||
"""Test if the MQTT component fails if server fails."""
|
"""Test if the MQTT component fails if server fails."""
|
||||||
from hbmqtt.broker import BrokerException
|
from hbmqtt.broker import BrokerException
|
||||||
|
|
|
@ -72,7 +72,7 @@ class TestSensorMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'beer on', 2, False),
|
self.assertEqual(('command-topic', 'beer on', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('switch.test')
|
state = self.hass.states.get('switch.test')
|
||||||
self.assertEqual(STATE_ON, state.state)
|
self.assertEqual(STATE_ON, state.state)
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class TestSensorMQTT(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(('command-topic', 'beer off', 2, False),
|
self.assertEqual(('command-topic', 'beer off', 2, False),
|
||||||
self.mock_publish.mock_calls[-1][1])
|
self.mock_publish.mock_calls[-2][1])
|
||||||
state = self.hass.states.get('switch.test')
|
state = self.hass.states.get('switch.test')
|
||||||
self.assertEqual(STATE_OFF, state.state)
|
self.assertEqual(STATE_OFF, state.state)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue