correct MQTT subscription filter (#7269)
* correct MQTT subscription filter * wildcard handling (#) fixed * wildcard handling (#) fixed * added tests for topic subscription like +/something/# * function names changed (line too long) * using raw strings for regular expression import order changed
This commit is contained in:
parent
570c5549a9
commit
0e08925259
2 changed files with 55 additions and 6 deletions
|
@ -10,6 +10,7 @@ import os
|
|||
import socket
|
||||
import time
|
||||
import ssl
|
||||
import re
|
||||
import requests.certs
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -639,12 +640,20 @@ def _raise_on_error(result):
|
|||
|
||||
def _match_topic(subscription, topic):
|
||||
"""Test if topic matches subscription."""
|
||||
reg_ex_parts = []
|
||||
suffix = ""
|
||||
if subscription.endswith('#'):
|
||||
return (subscription[:-2] == topic or
|
||||
topic.startswith(subscription[:-1]))
|
||||
|
||||
subscription = subscription[:-2]
|
||||
suffix = "(.*)"
|
||||
sub_parts = subscription.split('/')
|
||||
topic_parts = topic.split('/')
|
||||
for sub_part in sub_parts:
|
||||
if sub_part == "+":
|
||||
reg_ex_parts.append(r"([^\/]+)")
|
||||
else:
|
||||
reg_ex_parts.append(sub_part)
|
||||
|
||||
return (len(sub_parts) == len(topic_parts) and
|
||||
all(a == b for a, b in zip(sub_parts, topic_parts) if a != '+'))
|
||||
reg_ex = "^" + (r'\/'.join(reg_ex_parts)) + suffix + "$"
|
||||
|
||||
reg = re.compile(reg_ex)
|
||||
|
||||
return reg.match(topic) is not None
|
||||
|
|
|
@ -209,6 +209,46 @@ class TestMQTT(unittest.TestCase):
|
|||
self.hass.block_till_done()
|
||||
self.assertEqual(0, len(self.calls))
|
||||
|
||||
def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self):
|
||||
"""Test the subscription of wildcard topics."""
|
||||
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
|
||||
|
||||
fire_mqtt_message(self.hass, 'hi/test-topic', 'test-payload')
|
||||
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
self.assertEqual('hi/test-topic', self.calls[0][0])
|
||||
self.assertEqual('test-payload', self.calls[0][1])
|
||||
|
||||
def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
|
||||
"""Test the subscription of wildcard topics."""
|
||||
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
|
||||
|
||||
fire_mqtt_message(self.hass, 'hi/test-topic/here-iam', 'test-payload')
|
||||
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
self.assertEqual('hi/test-topic/here-iam', self.calls[0][0])
|
||||
self.assertEqual('test-payload', self.calls[0][1])
|
||||
|
||||
def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
|
||||
"""Test the subscription of wildcard topics."""
|
||||
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
|
||||
|
||||
fire_mqtt_message(self.hass, 'hi/here-iam/test-topic', 'test-payload')
|
||||
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(0, len(self.calls))
|
||||
|
||||
def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self):
|
||||
"""Test the subscription of wildcard topics."""
|
||||
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
|
||||
|
||||
fire_mqtt_message(self.hass, 'hi/another-test-topic', 'test-payload')
|
||||
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(0, len(self.calls))
|
||||
|
||||
def test_subscribe_binary_topic(self):
|
||||
"""Test the subscription to a binary topic."""
|
||||
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,
|
||||
|
|
Loading…
Add table
Reference in a new issue