From d2277fa6db01a204d68523b2c7d4eebebd0d12f9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Feb 2023 11:14:26 -0600 Subject: [PATCH] Fix restore of MQTT subscriptions from reload (#88220) --- homeassistant/components/mqtt/__init__.py | 8 +- homeassistant/components/mqtt/client.py | 103 ++++++++++++++++------ tests/components/mqtt/test_init.py | 80 ++++++++++++++--- 3 files changed, 150 insertions(+), 41 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 7096a473ec0..a1b194284c7 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -369,7 +369,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: mqtt_data.client = MQTT(hass, entry, conf) # Restore saved subscriptions if mqtt_data.subscriptions_to_restore: - mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore + mqtt_data.client.async_restore_tracked_subscriptions( + mqtt_data.subscriptions_to_restore + ) mqtt_data.subscriptions_to_restore = [] mqtt_data.reload_dispatchers.append( entry.add_update_listener(_async_config_entry_updated) @@ -730,7 +732,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await mqtt_client.async_disconnect() # Store remaining subscriptions to be able to restore or reload them # when the entry is set up again - if mqtt_client.subscriptions: - mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions + if subscriptions := mqtt_client.subscriptions: + mqtt_data.subscriptions_to_restore = subscriptions return True diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 6355e992a0d..ec866169709 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -5,7 +5,7 @@ import asyncio from collections.abc import Callable, Coroutine, Iterable from functools import lru_cache, partial, wraps import inspect -from itertools import groupby +from itertools import chain, groupby import logging from operator import attrgetter import ssl @@ -341,6 +341,11 @@ class MqttClientSetup: return self._client +def _is_simple_match(topic: str) -> bool: + """Return if a topic is a simple match.""" + return not ("+" in topic or "#" in topic) + + class MQTT: """Home Assistant MQTT client.""" @@ -358,7 +363,6 @@ class MQTT: self.hass = hass self.config_entry = config_entry self.conf = conf - self.subscriptions: list[Subscription] = [] self._simple_subscriptions: dict[str, list[Subscription]] = {} self._wildcard_subscriptions: list[Subscription] = [] self.connected = False @@ -390,6 +394,14 @@ class MQTT: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) ) + @property + def subscriptions(self) -> list[Subscription]: + """Return the tracked subscriptions.""" + return [ + *chain.from_iterable(self._simple_subscriptions.values()), + *self._wildcard_subscriptions, + ] + def cleanup(self) -> None: """Clean up listeners.""" while self._cleanup_on_unload: @@ -489,6 +501,50 @@ class MQTT: async with self._paho_lock: await self.hass.async_add_executor_job(stop) + @callback + def async_restore_tracked_subscriptions( + self, subscriptions: list[Subscription] + ) -> None: + """Restore tracked subscriptions after reload.""" + for subscription in subscriptions: + self._async_track_subscription(subscription) + self._matching_subscriptions.cache_clear() + + @callback + def _async_track_subscription(self, subscription: Subscription) -> None: + """Track a subscription. + + This method does not send a SUBSCRIBE message to the broker. + + The caller is responsible clearing the cache of _matching_subscriptions. + """ + if _is_simple_match(subscription.topic): + self._simple_subscriptions.setdefault(subscription.topic, []).append( + subscription + ) + else: + self._wildcard_subscriptions.append(subscription) + + @callback + def _async_untrack_subscription(self, subscription: Subscription) -> None: + """Untrack a subscription. + + This method does not send an UNSUBSCRIBE message to the broker. + + The caller is responsible clearing the cache of _matching_subscriptions. + """ + topic = subscription.topic + try: + if _is_simple_match(topic): + simple_subscriptions = self._simple_subscriptions + simple_subscriptions[topic].remove(subscription) + if not simple_subscriptions[topic]: + del simple_subscriptions[topic] + else: + self._wildcard_subscriptions.remove(subscription) + except (KeyError, ValueError) as ex: + raise HomeAssistantError("Can't remove subscription twice") from ex + async def async_subscribe( self, topic: str, @@ -506,11 +562,7 @@ class MQTT: subscription = Subscription( topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding ) - self.subscriptions.append(subscription) - if _is_simple := "+" not in topic and "#" not in topic: - self._simple_subscriptions.setdefault(topic, []).append(subscription) - else: - self._wildcard_subscriptions.append(subscription) + self._async_track_subscription(subscription) self._matching_subscriptions.cache_clear() # Only subscribe if currently connected. @@ -521,15 +573,7 @@ class MQTT: @callback def async_remove() -> None: """Remove subscription.""" - if subscription not in self.subscriptions: - raise HomeAssistantError("Can't remove subscription twice") - self.subscriptions.remove(subscription) - if _is_simple: - self._simple_subscriptions[topic].remove(subscription) - if not self._simple_subscriptions[topic]: - del self._simple_subscriptions[topic] - else: - self._wildcard_subscriptions.remove(subscription) + self._async_untrack_subscription(subscription) self._matching_subscriptions.cache_clear() # Only unsubscribe if currently connected @@ -636,18 +680,7 @@ class MQTT: result_code, ) - # Group subscriptions to only re-subscribe once for each topic. - keyfunc = attrgetter("topic") - self.hass.add_job( - self._async_perform_subscriptions, - [ - # Re-subscribe with the highest requested qos - (topic, max(subscription.qos for subscription in subs)) - for topic, subs in groupby( - sorted(self.subscriptions, key=keyfunc), keyfunc - ) - ], - ) + self.hass.create_task(self._async_resubscribe()) if ( CONF_BIRTH_MESSAGE in self.conf @@ -669,6 +702,20 @@ class MQTT: publish_birth_message(birth_message), self.hass.loop ) + async def _async_resubscribe(self) -> None: + """Resubscribe on reconnect.""" + # Group subscriptions to only re-subscribe once for each topic. + keyfunc = attrgetter("topic") + await self._async_perform_subscriptions( + [ + # Re-subscribe with the highest requested qos + (topic, max(subscription.qos for subscription in subs)) + for topic, subs in groupby( + sorted(self.subscriptions, key=keyfunc), keyfunc + ) + ] + ) + def _mqtt_on_message( self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage ) -> None: diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index b3e919f845d..58b8279f836 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1446,6 +1446,69 @@ async def test_restore_all_active_subscriptions_on_reconnect( assert mqtt_client_mock.subscribe.mock_calls == expected +async def test_reload_entry_with_restored_subscriptions( + hass: HomeAssistant, + tmp_path: Path, + mqtt_client_mock: MqttMockPahoClient, + record_calls: MessageCallbackType, + calls: list[ReceiveMessage], +) -> None: + """Test reloading the config entry with with subscriptions restored.""" + + entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) + entry.add_to_hass(hass) + mqtt_client_mock.connect.return_value = 0 + assert await mqtt.async_setup_entry(hass, entry) + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test-topic", record_calls) + await mqtt.async_subscribe(hass, "wild/+/card", record_calls) + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload") + + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[0].topic == "test-topic" + assert calls[0].payload == "test-payload" + assert calls[1].topic == "wild/any/card" + assert calls[1].payload == "wild-card-payload" + calls.clear() + + # Reload the entry + config_yaml_new = {} + await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new) + + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "test-topic", "test-payload2") + async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload2") + + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[0].topic == "test-topic" + assert calls[0].payload == "test-payload2" + assert calls[1].topic == "wild/any/card" + assert calls[1].payload == "wild-card-payload2" + calls.clear() + + # Reload the entry again + config_yaml_new = {} + await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new) + + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "test-topic", "test-payload3") + async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload3") + + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[0].topic == "test-topic" + assert calls[0].payload == "test-payload3" + assert calls[1].topic == "wild/any/card" + assert calls[1].payload == "wild-card-payload3" + + async def test_initial_setup_logs_error( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, @@ -2051,19 +2114,16 @@ async def test_mqtt_subscribes_topics_on_connect( await mqtt.async_subscribe(hass, "still/pending", record_calls) await mqtt.async_subscribe(hass, "still/pending", record_calls, 1) - with patch.object(hass, "add_job") as hass_jobs: - mqtt_client_mock.on_connect(None, None, 0, 0) + mqtt_client_mock.on_connect(None, None, 0, 0) - await hass.async_block_till_done() + await hass.async_block_till_done() - assert mqtt_client_mock.disconnect.call_count == 0 + assert mqtt_client_mock.disconnect.call_count == 0 - assert len(hass_jobs.mock_calls) == 1 - assert set(hass_jobs.mock_calls[0][1][1]) == { - ("home/sensor", 2), - ("still/pending", 1), - ("topic/test", 0), - } + assert mqtt_client_mock.subscribe.call_count == 3 + mqtt_client_mock.subscribe.assert_any_call("topic/test", 0) + mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2) + mqtt_client_mock.subscribe.assert_any_call("still/pending", 1) async def test_setup_entry_with_config_override(