diff --git a/homeassistant/components/binary_sensor/mqtt.py b/homeassistant/components/binary_sensor/mqtt.py index 28d9566b2ab..8c8beaddb6e 100644 --- a/homeassistant/components/binary_sensor/mqtt.py +++ b/homeassistant/components/binary_sensor/mqtt.py @@ -36,7 +36,10 @@ PLATFORM_SCHEMA = mqtt.MQTT_RO_PLATFORM_SCHEMA.extend({ # pylint: disable=unused-argument def setup_platform(hass, config, add_devices, discovery_info=None): - """Setup the MQTT binary sensor.""" + """Set up the MQTT binary sensor.""" + if discovery_info is not None: + config = PLATFORM_SCHEMA(discovery_info) + value_template = config.get(CONF_VALUE_TEMPLATE) if value_template is not None: value_template.hass = hass diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index ad4cce15cf3..e880be177e8 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -4,7 +4,6 @@ Support for MQTT message handling. For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ -import asyncio import logging import os import socket @@ -12,6 +11,7 @@ import time import voluptuous as vol +from homeassistant.core import callback from homeassistant.bootstrap import prepare_setup_platform from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError @@ -36,6 +36,8 @@ REQUIREMENTS = ['paho-mqtt==1.2'] CONF_EMBEDDED = 'embedded' CONF_BROKER = 'broker' CONF_CLIENT_ID = 'client_id' +CONF_DISCOVERY = 'discovery' +CONF_DISCOVERY_PREFIX = 'discovery_prefix' CONF_KEEPALIVE = 'keepalive' CONF_CERTIFICATE = 'certificate' CONF_CLIENT_KEY = 'client_key' @@ -58,6 +60,8 @@ DEFAULT_KEEPALIVE = 60 DEFAULT_QOS = 0 DEFAULT_RETAIN = False DEFAULT_PROTOCOL = PROTOCOL_311 +DEFAULT_DISCOVERY = False +DEFAULT_DISCOVERY_PREFIX = 'homeassistant' ATTR_TOPIC = 'topic' ATTR_PAYLOAD = 'payload' @@ -70,7 +74,8 @@ MAX_RECONNECT_WAIT = 300 # seconds def valid_subscribe_topic(value, invalid_chars='\0'): """Validate that we can subscribe using this MQTT topic.""" - if isinstance(value, str) and all(c not in value for c in invalid_chars): + value = cv.string(value) + if all(c not in value for c in invalid_chars): return vol.Length(min=1, max=65535)(value) raise vol.Invalid('Invalid MQTT topic name') @@ -80,6 +85,11 @@ def valid_publish_topic(value): return valid_subscribe_topic(value, invalid_chars='#+\0') +def valid_discovery_topic(value): + """Validate a discovery topic.""" + return valid_subscribe_topic(value, invalid_chars='#+\0/') + + _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) CLIENT_KEY_AUTH_MSG = 'client_key and client_cert must both be present in ' \ @@ -111,7 +121,10 @@ CONFIG_SCHEMA = vol.Schema({ vol.All(cv.string, vol.In([PROTOCOL_31, PROTOCOL_311])), vol.Optional(CONF_EMBEDDED): HBMQTT_CONFIG_SCHEMA, vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, - vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA + vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, + vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean, + vol.Optional(CONF_DISCOVERY_PREFIX, + default=DEFAULT_DISCOVERY_PREFIX): valid_discovery_topic, }), }, extra=vol.ALLOW_EXTRA) @@ -170,15 +183,16 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): hass.services.call(DOMAIN, SERVICE_PUBLISH, data) -def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS): +@callback +def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): """Subscribe to an MQTT topic.""" - @asyncio.coroutine + @callback def mqtt_topic_subscriber(event): """Match subscribed MQTT topic.""" if not _match_topic(topic, event.data[ATTR_TOPIC]): return - hass.async_run_job(callback, event.data[ATTR_TOPIC], + hass.async_run_job(msg_callback, event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, @@ -213,6 +227,21 @@ def _setup_server(hass, config): return success and broker_config +def _setup_discovery(hass, config): + """Try to start the discovery of MQTT devices.""" + conf = config.get(DOMAIN, {}) + + discovery = prepare_setup_platform(hass, config, DOMAIN, 'discovery') + + if discovery is None: + _LOGGER.error("Unable to load MQTT discovery") + return None + + success = discovery.start(hass, conf[CONF_DISCOVERY_PREFIX], config) + + return success + + def setup(hass, config): """Start the MQTT protocol service.""" conf = config.get(DOMAIN, {}) @@ -301,6 +330,9 @@ def setup(hass, config): descriptions.get(SERVICE_PUBLISH), schema=MQTT_PUBLISH_SCHEMA) + if conf.get(CONF_DISCOVERY): + _setup_discovery(hass, config) + return True diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py new file mode 100644 index 00000000000..ca2d37bbbba --- /dev/null +++ b/homeassistant/components/mqtt/discovery.py @@ -0,0 +1,62 @@ +""" +Support for MQTT discovery. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/mqtt/#discovery +""" +import asyncio +import json +import logging +import re + +from homeassistant.core import callback +import homeassistant.components.mqtt as mqtt +from homeassistant.components.mqtt import DOMAIN +from homeassistant.helpers.discovery import async_load_platform +from homeassistant.const import CONF_PLATFORM +from homeassistant.components.mqtt import CONF_STATE_TOPIC + +_LOGGER = logging.getLogger(__name__) + +TOPIC_MATCHER = re.compile( + r'homeassistant/(?P\w+)/(?P\w+)/config') +SUPPORTED_COMPONENTS = ['binary_sensor'] + + +@callback +def async_start(hass, discovery_topic, hass_config): + """Initialization of MQTT Discovery.""" + @asyncio.coroutine + def async_device_message_received(topic, payload, qos): + """Process the received message.""" + match = TOPIC_MATCHER.match(topic) + + if not match: + return + + component, object_id = match.groups() + + try: + payload = json.loads(payload) + except ValueError: + _LOGGER.warning( + "Unable to parse JSON %s: %s", object_id, payload) + return + + if component not in SUPPORTED_COMPONENTS: + _LOGGER.warning("Component %s is not supported", component) + return + + payload = dict(payload) + payload[CONF_PLATFORM] = 'mqtt' + if CONF_STATE_TOPIC not in payload: + payload[CONF_STATE_TOPIC] = '{}/{}/{}/state'.format( + discovery_topic, component, object_id) + + yield from async_load_platform( + hass, component, DOMAIN, payload, hass_config) + + mqtt.async_subscribe(hass, discovery_topic + '/#', + async_device_message_received, 0) + + return True diff --git a/tests/common.py b/tests/common.py index 98a3102edf7..5ebca8640cc 100644 --- a/tests/common.py +++ b/tests/common.py @@ -26,6 +26,7 @@ from homeassistant.components import sun, mqtt from homeassistant.components.http.auth import auth_middleware from homeassistant.components.http.const import ( KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS) +from homeassistant.util.async import run_callback_threadsafe _TEST_INSTANCE_PORT = SERVER_PORT _LOGGER = logging.getLogger(__name__) @@ -147,15 +148,22 @@ def mock_service(hass, domain, service): return calls -def fire_mqtt_message(hass, topic, payload, qos=0): +@ha.callback +def async_fire_mqtt_message(hass, topic, payload, qos=0): """Fire the MQTT message.""" - hass.bus.fire(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, { + hass.bus.async_fire(mqtt.EVENT_MQTT_MESSAGE_RECEIVED, { mqtt.ATTR_TOPIC: topic, mqtt.ATTR_PAYLOAD: payload, mqtt.ATTR_QOS: qos, }) +def fire_mqtt_message(hass, topic, payload, qos=0): + """Fire the MQTT message.""" + run_callback_threadsafe( + hass.loop, async_fire_mqtt_message, hass, topic, payload, qos).result() + + def fire_time_changed(hass, time): """Fire a time changes event.""" hass.bus.fire(EVENT_TIME_CHANGED, {'now': time}) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py new file mode 100644 index 00000000000..bf6fa2f2603 --- /dev/null +++ b/tests/components/mqtt/test_discovery.py @@ -0,0 +1,74 @@ +"""The tests for the MQTT component.""" +import asyncio +from unittest.mock import patch + +from homeassistant.components.mqtt.discovery import async_start + +from tests.common import async_fire_mqtt_message, mock_coro + + +@asyncio.coroutine +def test_subscribing_config_topic(hass, mqtt_mock): + """Test setting up discovery.""" + hass_config = {} + discovery_topic = 'homeassistant' + async_start(hass, discovery_topic, hass_config) + assert mqtt_mock.subscribe.called + call_args = mqtt_mock.subscribe.mock_calls[0][1] + assert call_args[0] == discovery_topic + '/#' + assert call_args[1] == 0 + + +@asyncio.coroutine +@patch('homeassistant.components.mqtt.discovery.async_load_platform') +def test_invalid_topic(mock_load_platform, hass, mqtt_mock): + """Test sending in invalid JSON.""" + mock_load_platform.return_value = mock_coro() + async_start(hass, 'homeassistant', {}) + + async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config', + '{}') + yield from hass.async_block_till_done() + assert not mock_load_platform.called + + +@asyncio.coroutine +@patch('homeassistant.components.mqtt.discovery.async_load_platform') +def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog): + """Test sending in invalid JSON.""" + mock_load_platform.return_value = mock_coro() + async_start(hass, 'homeassistant', {}) + + async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', + 'not json') + yield from hass.async_block_till_done() + assert 'Unable to parse JSON' in caplog.text + assert not mock_load_platform.called + + +@asyncio.coroutine +@patch('homeassistant.components.mqtt.discovery.async_load_platform') +def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog): + """Test sending in invalid JSON.""" + mock_load_platform.return_value = mock_coro() + async_start(hass, 'homeassistant', {}) + + async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}') + yield from hass.async_block_till_done() + assert 'Component climate is not supported' in caplog.text + assert not mock_load_platform.called + + +@asyncio.coroutine +def test_correct_config_discovery(hass, mqtt_mock, caplog): + """Test sending in invalid JSON.""" + async_start(hass, 'homeassistant', {}) + + async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', + '{ "name": "Beer" }') + yield from hass.async_block_till_done() + + state = hass.states.get('binary_sensor.beer') + + assert state is not None + assert state.name == 'Beer' diff --git a/tests/conftest.py b/tests/conftest.py index 54f5404d72d..d408ed254f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,14 @@ """Setup some common test helper things.""" import functools import logging +from unittest.mock import patch import pytest import requests_mock as _requests_mock -from homeassistant import util +from homeassistant import util, bootstrap from homeassistant.util import location +from homeassistant.components import mqtt from .common import async_test_home_assistant from .test_util.aiohttp import mock_aiohttp_client @@ -58,3 +60,18 @@ def aioclient_mock(): """Fixture to mock aioclient calls.""" with mock_aiohttp_client() as mock_session: yield mock_session + + +@pytest.fixture +def mqtt_mock(loop, hass): + """Fixture to mock MQTT.""" + with patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: + loop.run_until_complete(bootstrap.async_setup_component( + hass, mqtt.DOMAIN, { + mqtt.DOMAIN: { + mqtt.CONF_BROKER: 'mock-broker', + } + })) + client = mock_mqtt() + client.reset_mock() + return client diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py index 1d0bbbd8dfd..23dde3a8244 100644 --- a/tests/scripts/test_check_config.py +++ b/tests/scripts/test_check_config.py @@ -101,7 +101,13 @@ class TestCheckConfig(unittest.TestCase): res = check_config.check(get_test_config_dir('platform.yaml')) change_yaml_files(res) self.assertDictEqual( - {'mqtt': {'keepalive': 60, 'port': 1883, 'protocol': '3.1.1'}, + {'mqtt': { + 'keepalive': 60, + 'port': 1883, + 'protocol': '3.1.1', + 'discovery': False, + 'discovery_prefix': 'homeassistant', + }, 'light': []}, res['components'] )