diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 458c5952a69..89c003c070c 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -10,6 +10,7 @@ import os import socket import time import ssl +import re import requests.certs import voluptuous as vol @@ -639,12 +640,20 @@ def _raise_on_error(result): def _match_topic(subscription, topic): """Test if topic matches subscription.""" + reg_ex_parts = [] + suffix = "" if subscription.endswith('#'): - return (subscription[:-2] == topic or - topic.startswith(subscription[:-1])) - + subscription = subscription[:-2] + suffix = "(.*)" sub_parts = subscription.split('/') - topic_parts = topic.split('/') + for sub_part in sub_parts: + if sub_part == "+": + reg_ex_parts.append(r"([^\/]+)") + else: + reg_ex_parts.append(sub_part) - return (len(sub_parts) == len(topic_parts) and - all(a == b for a, b in zip(sub_parts, topic_parts) if a != '+')) + reg_ex = "^" + (r'\/'.join(reg_ex_parts)) + suffix + "$" + + reg = re.compile(reg_ex) + + return reg.match(topic) is not None diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 0017674e82f..0ef512edcd6 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -209,6 +209,46 @@ class TestMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(0, len(self.calls)) + def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self): + """Test the subscription of wildcard topics.""" + mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls) + + fire_mqtt_message(self.hass, 'hi/test-topic', 'test-payload') + + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual('hi/test-topic', self.calls[0][0]) + self.assertEqual('test-payload', self.calls[0][1]) + + def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self): + """Test the subscription of wildcard topics.""" + mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls) + + fire_mqtt_message(self.hass, 'hi/test-topic/here-iam', 'test-payload') + + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual('hi/test-topic/here-iam', self.calls[0][0]) + self.assertEqual('test-payload', self.calls[0][1]) + + def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self): + """Test the subscription of wildcard topics.""" + mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls) + + fire_mqtt_message(self.hass, 'hi/here-iam/test-topic', 'test-payload') + + self.hass.block_till_done() + self.assertEqual(0, len(self.calls)) + + def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self): + """Test the subscription of wildcard topics.""" + mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls) + + fire_mqtt_message(self.hass, 'hi/another-test-topic', 'test-payload') + + self.hass.block_till_done() + self.assertEqual(0, len(self.calls)) + def test_subscribe_binary_topic(self): """Test the subscription to a binary topic.""" mqtt.subscribe(self.hass, 'test-topic', self.record_calls,