Fix restore of MQTT subscriptions from reload (#88220)

This commit is contained in:
J. Nick Koston 2023-02-16 11:14:26 -06:00 committed by GitHub
parent d03655cb6f
commit d2277fa6db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 150 additions and 41 deletions

View file

@ -369,7 +369,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
mqtt_data.client = MQTT(hass, entry, conf) mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions # Restore saved subscriptions
if mqtt_data.subscriptions_to_restore: if mqtt_data.subscriptions_to_restore:
mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore mqtt_data.client.async_restore_tracked_subscriptions(
mqtt_data.subscriptions_to_restore
)
mqtt_data.subscriptions_to_restore = [] mqtt_data.subscriptions_to_restore = []
mqtt_data.reload_dispatchers.append( mqtt_data.reload_dispatchers.append(
entry.add_update_listener(_async_config_entry_updated) entry.add_update_listener(_async_config_entry_updated)
@ -730,7 +732,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await mqtt_client.async_disconnect() await mqtt_client.async_disconnect()
# Store remaining subscriptions to be able to restore or reload them # Store remaining subscriptions to be able to restore or reload them
# when the entry is set up again # when the entry is set up again
if mqtt_client.subscriptions: if subscriptions := mqtt_client.subscriptions:
mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions mqtt_data.subscriptions_to_restore = subscriptions
return True return True

View file

@ -5,7 +5,7 @@ import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
import inspect import inspect
from itertools import groupby from itertools import chain, groupby
import logging import logging
from operator import attrgetter from operator import attrgetter
import ssl import ssl
@ -341,6 +341,11 @@ class MqttClientSetup:
return self._client return self._client
def _is_simple_match(topic: str) -> bool:
"""Return if a topic is a simple match."""
return not ("+" in topic or "#" in topic)
class MQTT: class MQTT:
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
@ -358,7 +363,6 @@ class MQTT:
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self.subscriptions: list[Subscription] = []
self._simple_subscriptions: dict[str, list[Subscription]] = {} self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: list[Subscription] = []
self.connected = False self.connected = False
@ -390,6 +394,14 @@ class MQTT:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
) )
@property
def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions."""
return [
*chain.from_iterable(self._simple_subscriptions.values()),
*self._wildcard_subscriptions,
]
def cleanup(self) -> None: def cleanup(self) -> None:
"""Clean up listeners.""" """Clean up listeners."""
while self._cleanup_on_unload: while self._cleanup_on_unload:
@ -489,6 +501,50 @@ class MQTT:
async with self._paho_lock: async with self._paho_lock:
await self.hass.async_add_executor_job(stop) await self.hass.async_add_executor_job(stop)
@callback
def async_restore_tracked_subscriptions(
self, subscriptions: list[Subscription]
) -> None:
"""Restore tracked subscriptions after reload."""
for subscription in subscriptions:
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
@callback
def _async_track_subscription(self, subscription: Subscription) -> None:
"""Track a subscription.
This method does not send a SUBSCRIBE message to the broker.
The caller is responsible clearing the cache of _matching_subscriptions.
"""
if _is_simple_match(subscription.topic):
self._simple_subscriptions.setdefault(subscription.topic, []).append(
subscription
)
else:
self._wildcard_subscriptions.append(subscription)
@callback
def _async_untrack_subscription(self, subscription: Subscription) -> None:
"""Untrack a subscription.
This method does not send an UNSUBSCRIBE message to the broker.
The caller is responsible clearing the cache of _matching_subscriptions.
"""
topic = subscription.topic
try:
if _is_simple_match(topic):
simple_subscriptions = self._simple_subscriptions
simple_subscriptions[topic].remove(subscription)
if not simple_subscriptions[topic]:
del simple_subscriptions[topic]
else:
self._wildcard_subscriptions.remove(subscription)
except (KeyError, ValueError) as ex:
raise HomeAssistantError("Can't remove subscription twice") from ex
async def async_subscribe( async def async_subscribe(
self, self,
topic: str, topic: str,
@ -506,11 +562,7 @@ class MQTT:
subscription = Subscription( subscription = Subscription(
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
) )
self.subscriptions.append(subscription) self._async_track_subscription(subscription)
if _is_simple := "+" not in topic and "#" not in topic:
self._simple_subscriptions.setdefault(topic, []).append(subscription)
else:
self._wildcard_subscriptions.append(subscription)
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
# Only subscribe if currently connected. # Only subscribe if currently connected.
@ -521,15 +573,7 @@ class MQTT:
@callback @callback
def async_remove() -> None: def async_remove() -> None:
"""Remove subscription.""" """Remove subscription."""
if subscription not in self.subscriptions: self._async_untrack_subscription(subscription)
raise HomeAssistantError("Can't remove subscription twice")
self.subscriptions.remove(subscription)
if _is_simple:
self._simple_subscriptions[topic].remove(subscription)
if not self._simple_subscriptions[topic]:
del self._simple_subscriptions[topic]
else:
self._wildcard_subscriptions.remove(subscription)
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
# Only unsubscribe if currently connected # Only unsubscribe if currently connected
@ -636,18 +680,7 @@ class MQTT:
result_code, result_code,
) )
# Group subscriptions to only re-subscribe once for each topic. self.hass.create_task(self._async_resubscribe())
keyfunc = attrgetter("topic")
self.hass.add_job(
self._async_perform_subscriptions,
[
# Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc
)
],
)
if ( if (
CONF_BIRTH_MESSAGE in self.conf CONF_BIRTH_MESSAGE in self.conf
@ -669,6 +702,20 @@ class MQTT:
publish_birth_message(birth_message), self.hass.loop publish_birth_message(birth_message), self.hass.loop
) )
async def _async_resubscribe(self) -> None:
"""Resubscribe on reconnect."""
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic")
await self._async_perform_subscriptions(
[
# Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc
)
]
)
def _mqtt_on_message( def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None: ) -> None:

View file

@ -1446,6 +1446,69 @@ async def test_restore_all_active_subscriptions_on_reconnect(
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
async def test_reload_entry_with_restored_subscriptions(
hass: HomeAssistant,
tmp_path: Path,
mqtt_client_mock: MqttMockPahoClient,
record_calls: MessageCallbackType,
calls: list[ReceiveMessage],
) -> None:
"""Test reloading the config entry with with subscriptions restored."""
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass)
mqtt_client_mock.connect.return_value = 0
assert await mqtt.async_setup_entry(hass, entry)
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test-topic", record_calls)
await mqtt.async_subscribe(hass, "wild/+/card", record_calls)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].topic == "test-topic"
assert calls[0].payload == "test-payload"
assert calls[1].topic == "wild/any/card"
assert calls[1].payload == "wild-card-payload"
calls.clear()
# Reload the entry
config_yaml_new = {}
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test-topic", "test-payload2")
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload2")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].topic == "test-topic"
assert calls[0].payload == "test-payload2"
assert calls[1].topic == "wild/any/card"
assert calls[1].payload == "wild-card-payload2"
calls.clear()
# Reload the entry again
config_yaml_new = {}
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test-topic", "test-payload3")
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload3")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].topic == "test-topic"
assert calls[0].payload == "test-payload3"
assert calls[1].topic == "wild/any/card"
assert calls[1].payload == "wild-card-payload3"
async def test_initial_setup_logs_error( async def test_initial_setup_logs_error(
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
@ -2051,19 +2114,16 @@ async def test_mqtt_subscribes_topics_on_connect(
await mqtt.async_subscribe(hass, "still/pending", record_calls) await mqtt.async_subscribe(hass, "still/pending", record_calls)
await mqtt.async_subscribe(hass, "still/pending", record_calls, 1) await mqtt.async_subscribe(hass, "still/pending", record_calls, 1)
with patch.object(hass, "add_job") as hass_jobs: mqtt_client_mock.on_connect(None, None, 0, 0)
mqtt_client_mock.on_connect(None, None, 0, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client_mock.disconnect.call_count == 0 assert mqtt_client_mock.disconnect.call_count == 0
assert len(hass_jobs.mock_calls) == 1 assert mqtt_client_mock.subscribe.call_count == 3
assert set(hass_jobs.mock_calls[0][1][1]) == { mqtt_client_mock.subscribe.assert_any_call("topic/test", 0)
("home/sensor", 2), mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2)
("still/pending", 1), mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
("topic/test", 0),
}
async def test_setup_entry_with_config_override( async def test_setup_entry_with_config_override(