From 4e49bd0596094fa2eab6da40e4562109f8fdd41c Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 10 Nov 2020 21:55:26 +0100 Subject: [PATCH] Correct handling of existing MQTT subscriptions (#43056) --- homeassistant/components/mqtt/subscription.py | 1 + tests/components/mqtt/test_subscription.py | 28 ++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index feccdc33bc2..24c1c6ff3a1 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -29,6 +29,7 @@ class EntitySubscription: async def resubscribe_if_necessary(self, hass, other): """Re-subscribe to the new topic if necessary.""" if not self._should_resubscribe(other): + self.unsubscribe_callback = other.unsubscribe_callback return if other is not None and other.unsubscribe_callback is not None: diff --git a/tests/components/mqtt/test_subscription.py b/tests/components/mqtt/test_subscription.py index 19797f28c3f..f1ed26e89cc 100644 --- a/tests/components/mqtt/test_subscription.py +++ b/tests/components/mqtt/test_subscription.py @@ -160,21 +160,35 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog): async def test_no_change(hass, mqtt_mock, caplog): """Test subscription to topics without change.""" + calls = [] + @callback - def msg_callback(*args): - """Do nothing.""" - pass + def record_calls(*args): + """Record calls.""" + calls.append(args) sub_state = None sub_state = await async_subscribe_topics( hass, sub_state, - {"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}}, + {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, ) - call_count = mqtt_mock.async_subscribe.call_count + subscribe_call_count = mqtt_mock.async_subscribe.call_count + + async_fire_mqtt_message(hass, "test-topic1", "test-payload") + assert len(calls) == 1 + sub_state = await async_subscribe_topics( hass, sub_state, - {"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}}, + {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, ) - assert call_count == mqtt_mock.async_subscribe.call_count + assert subscribe_call_count == mqtt_mock.async_subscribe.call_count + + async_fire_mqtt_message(hass, "test-topic1", "test-payload") + assert len(calls) == 2 + + await async_unsubscribe_topics(hass, sub_state) + + async_fire_mqtt_message(hass, "test-topic1", "test-payload") + assert len(calls) == 2