allow wildcards in subscription (#12247)
* allow wildcards in subscription * remove whitespaces * make function public * also implement for mqtt_json * avoid mqtt-outside topic matching * add wildcard tests * add not matching wildcard tests * fix not-matching tests
This commit is contained in:
parent
1db4df6d3a
commit
cad9e9a4cb
4 changed files with 175 additions and 34 deletions
|
@ -31,17 +31,14 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
|
||||||
devices = config[CONF_DEVICES]
|
devices = config[CONF_DEVICES]
|
||||||
qos = config[CONF_QOS]
|
qos = config[CONF_QOS]
|
||||||
|
|
||||||
dev_id_lookup = {}
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_tracker_message_received(topic, payload, qos):
|
|
||||||
"""Handle received MQTT message."""
|
|
||||||
hass.async_add_job(
|
|
||||||
async_see(dev_id=dev_id_lookup[topic], location_name=payload))
|
|
||||||
|
|
||||||
for dev_id, topic in devices.items():
|
for dev_id, topic in devices.items():
|
||||||
dev_id_lookup[topic] = dev_id
|
@callback
|
||||||
|
def async_message_received(topic, payload, qos, dev_id=dev_id):
|
||||||
|
"""Handle received MQTT message."""
|
||||||
|
hass.async_add_job(
|
||||||
|
async_see(dev_id=dev_id, location_name=payload))
|
||||||
|
|
||||||
yield from mqtt.async_subscribe(
|
yield from mqtt.async_subscribe(
|
||||||
hass, topic, async_tracker_message_received, qos)
|
hass, topic, async_message_received, qos)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -41,32 +41,26 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
|
||||||
devices = config[CONF_DEVICES]
|
devices = config[CONF_DEVICES]
|
||||||
qos = config[CONF_QOS]
|
qos = config[CONF_QOS]
|
||||||
|
|
||||||
dev_id_lookup = {}
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_tracker_message_received(topic, payload, qos):
|
|
||||||
"""Handle received MQTT message."""
|
|
||||||
dev_id = dev_id_lookup[topic]
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
|
|
||||||
except vol.MultipleInvalid:
|
|
||||||
_LOGGER.error("Skipping update for following data "
|
|
||||||
"because of missing or malformatted data: %s",
|
|
||||||
payload)
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
_LOGGER.error("Error parsing JSON payload: %s", payload)
|
|
||||||
return
|
|
||||||
|
|
||||||
kwargs = _parse_see_args(dev_id, data)
|
|
||||||
hass.async_add_job(
|
|
||||||
async_see(**kwargs))
|
|
||||||
|
|
||||||
for dev_id, topic in devices.items():
|
for dev_id, topic in devices.items():
|
||||||
dev_id_lookup[topic] = dev_id
|
@callback
|
||||||
|
def async_message_received(topic, payload, qos, dev_id=dev_id):
|
||||||
|
"""Handle received MQTT message."""
|
||||||
|
try:
|
||||||
|
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
|
||||||
|
except vol.MultipleInvalid:
|
||||||
|
_LOGGER.error("Skipping update for following data "
|
||||||
|
"because of missing or malformatted data: %s",
|
||||||
|
payload)
|
||||||
|
return
|
||||||
|
except ValueError:
|
||||||
|
_LOGGER.error("Error parsing JSON payload: %s", payload)
|
||||||
|
return
|
||||||
|
|
||||||
|
kwargs = _parse_see_args(dev_id, data)
|
||||||
|
hass.async_add_job(async_see(**kwargs))
|
||||||
|
|
||||||
yield from mqtt.async_subscribe(
|
yield from mqtt.async_subscribe(
|
||||||
hass, topic, async_tracker_message_received, qos)
|
hass, topic, async_message_received, qos)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -70,3 +70,79 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
|
||||||
fire_mqtt_message(self.hass, topic, location)
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||||
|
|
||||||
|
def test_single_level_wildcard_topic(self):
|
||||||
|
"""Test single level wildcard topic."""
|
||||||
|
dev_id = 'paulus'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = '/location/+/paulus'
|
||||||
|
topic = '/location/room/paulus'
|
||||||
|
location = 'work'
|
||||||
|
|
||||||
|
self.hass.config.components = set(['mqtt', 'zone'])
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||||
|
|
||||||
|
def test_multi_level_wildcard_topic(self):
|
||||||
|
"""Test multi level wildcard topic."""
|
||||||
|
dev_id = 'paulus'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = '/location/#'
|
||||||
|
topic = '/location/room/paulus'
|
||||||
|
location = 'work'
|
||||||
|
|
||||||
|
self.hass.config.components = set(['mqtt', 'zone'])
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||||
|
|
||||||
|
def test_single_level_wildcard_topic_not_matching(self):
|
||||||
|
"""Test not matching single level wildcard topic."""
|
||||||
|
dev_id = 'paulus'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = '/location/+/paulus'
|
||||||
|
topic = '/location/paulus'
|
||||||
|
location = 'work'
|
||||||
|
|
||||||
|
self.hass.config.components = set(['mqtt', 'zone'])
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertIsNone(self.hass.states.get(entity_id))
|
||||||
|
|
||||||
|
def test_multi_level_wildcard_topic_not_matching(self):
|
||||||
|
"""Test not matching multi level wildcard topic."""
|
||||||
|
dev_id = 'paulus'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = '/location/#'
|
||||||
|
topic = '/somewhere/room/paulus'
|
||||||
|
location = 'work'
|
||||||
|
|
||||||
|
self.hass.config.components = set(['mqtt', 'zone'])
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertIsNone(self.hass.states.get(entity_id))
|
||||||
|
|
|
@ -123,3 +123,77 @@ class TestComponentsDeviceTrackerJSONMQTT(unittest.TestCase):
|
||||||
"Skipping update for following data because of missing "
|
"Skipping update for following data because of missing "
|
||||||
"or malformatted data: {\"longitude\": 2.0}",
|
"or malformatted data: {\"longitude\": 2.0}",
|
||||||
test_handle.output[0])
|
test_handle.output[0])
|
||||||
|
|
||||||
|
def test_single_level_wildcard_topic(self):
|
||||||
|
"""Test single level wildcard topic."""
|
||||||
|
dev_id = 'zanzito'
|
||||||
|
subscription = 'location/+/zanzito'
|
||||||
|
topic = 'location/room/zanzito'
|
||||||
|
location = json.dumps(LOCATION_MESSAGE)
|
||||||
|
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt_json',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
state = self.hass.states.get('device_tracker.zanzito')
|
||||||
|
self.assertEqual(state.attributes.get('latitude'), 2.0)
|
||||||
|
self.assertEqual(state.attributes.get('longitude'), 1.0)
|
||||||
|
|
||||||
|
def test_multi_level_wildcard_topic(self):
|
||||||
|
"""Test multi level wildcard topic."""
|
||||||
|
dev_id = 'zanzito'
|
||||||
|
subscription = 'location/#'
|
||||||
|
topic = 'location/zanzito'
|
||||||
|
location = json.dumps(LOCATION_MESSAGE)
|
||||||
|
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt_json',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
state = self.hass.states.get('device_tracker.zanzito')
|
||||||
|
self.assertEqual(state.attributes.get('latitude'), 2.0)
|
||||||
|
self.assertEqual(state.attributes.get('longitude'), 1.0)
|
||||||
|
|
||||||
|
def test_single_level_wildcard_topic_not_matching(self):
|
||||||
|
"""Test not matching single level wildcard topic."""
|
||||||
|
dev_id = 'zanzito'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = 'location/+/zanzito'
|
||||||
|
topic = 'location/zanzito'
|
||||||
|
location = json.dumps(LOCATION_MESSAGE)
|
||||||
|
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt_json',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertIsNone(self.hass.states.get(entity_id))
|
||||||
|
|
||||||
|
def test_multi_level_wildcard_topic_not_matching(self):
|
||||||
|
"""Test not matching multi level wildcard topic."""
|
||||||
|
dev_id = 'zanzito'
|
||||||
|
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||||
|
subscription = 'location/#'
|
||||||
|
topic = 'somewhere/zanzito'
|
||||||
|
location = json.dumps(LOCATION_MESSAGE)
|
||||||
|
|
||||||
|
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
|
device_tracker.DOMAIN: {
|
||||||
|
CONF_PLATFORM: 'mqtt_json',
|
||||||
|
'devices': {dev_id: subscription}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fire_mqtt_message(self.hass, topic, location)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
self.assertIsNone(self.hass.states.get(entity_id))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue