diff --git a/homeassistant/components/device_tracker/mqtt.py b/homeassistant/components/device_tracker/mqtt.py index d754156f217..609d8cc713a 100644 --- a/homeassistant/components/device_tracker/mqtt.py +++ b/homeassistant/components/device_tracker/mqtt.py @@ -6,28 +6,26 @@ https://home-assistant.io/components/device_tracker.mqtt/ """ import logging +import voluptuous as vol + import homeassistant.components.mqtt as mqtt -from homeassistant import util +import homeassistant.helpers.config_validation as cv DEPENDENCIES = ['mqtt'] -CONF_QOS = 'qos' CONF_DEVICES = 'devices' -DEFAULT_QOS = 0 - _LOGGER = logging.getLogger(__name__) +PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend({ + vol.Required(CONF_DEVICES): {cv.string: mqtt.valid_subscribe_topic}, +}) + def setup_scanner(hass, config, see): """Setup the MQTT tracker.""" - devices = config.get(CONF_DEVICES) - qos = util.convert(config.get(CONF_QOS), int, DEFAULT_QOS) - - if not isinstance(devices, dict): - _LOGGER.error('Expected %s to be a dict, found %s', CONF_DEVICES, - devices) - return False + devices = config[CONF_DEVICES] + qos = config[mqtt.CONF_QOS] dev_id_lookup = {} diff --git a/tests/components/device_tracker/test_mqtt.py b/tests/components/device_tracker/test_mqtt.py index 7b6024c60f1..139316a35bf 100644 --- a/tests/components/device_tracker/test_mqtt.py +++ b/tests/components/device_tracker/test_mqtt.py @@ -2,6 +2,7 @@ import unittest import os +from homeassistant.bootstrap import _setup_component from homeassistant.components import device_tracker from homeassistant.const import CONF_PLATFORM @@ -31,11 +32,13 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase): topic = '/location/paulus' location = 'work' - self.assertTrue(device_tracker.setup(self.hass, { + self.hass.config.components = ['mqtt', 'zone'] + assert _setup_component(self.hass, device_tracker.DOMAIN, { device_tracker.DOMAIN: { CONF_PLATFORM: 'mqtt', 'devices': {dev_id: topic} - }})) + } + }) fire_mqtt_message(self.hass, topic, location) self.hass.pool.block_till_done() self.assertEqual(location, self.hass.states.get(enttiy_id).state)