correct MQTT subscription filter (#7269)

* correct MQTT subscription filter

* wildcard handling (#) fixed

* wildcard handling (#) fixed

* added tests for topic subscription like +/something/#

* function names changed (line too long)

* using raw strings for regular expression
import order changed
This commit is contained in:
amigian74 2017-05-02 18:18:34 +02:00 committed by Paulus Schoutsen
parent 570c5549a9
commit 0e08925259
2 changed files with 55 additions and 6 deletions

View file

@ -10,6 +10,7 @@ import os
import socket import socket
import time import time
import ssl import ssl
import re
import requests.certs import requests.certs
import voluptuous as vol import voluptuous as vol
@ -639,12 +640,20 @@ def _raise_on_error(result):
def _match_topic(subscription, topic): def _match_topic(subscription, topic):
"""Test if topic matches subscription.""" """Test if topic matches subscription."""
reg_ex_parts = []
suffix = ""
if subscription.endswith('#'): if subscription.endswith('#'):
return (subscription[:-2] == topic or subscription = subscription[:-2]
topic.startswith(subscription[:-1])) suffix = "(.*)"
sub_parts = subscription.split('/') sub_parts = subscription.split('/')
topic_parts = topic.split('/') for sub_part in sub_parts:
if sub_part == "+":
reg_ex_parts.append(r"([^\/]+)")
else:
reg_ex_parts.append(sub_part)
return (len(sub_parts) == len(topic_parts) and reg_ex = "^" + (r'\/'.join(reg_ex_parts)) + suffix + "$"
all(a == b for a, b in zip(sub_parts, topic_parts) if a != '+'))
reg = re.compile(reg_ex)
return reg.match(topic) is not None

View file

@ -209,6 +209,46 @@ class TestMQTT(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(0, len(self.calls)) self.assertEqual(0, len(self.calls))
def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/test-topic/here-iam', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic/here-iam', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/here-iam/test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/another-test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_binary_topic(self): def test_subscribe_binary_topic(self):
"""Test the subscription to a binary topic.""" """Test the subscription to a binary topic."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls, mqtt.subscribe(self.hass, 'test-topic', self.record_calls,