Fix restore of MQTT subscriptions from reload (#88220)
This commit is contained in:
parent
d03655cb6f
commit
d2277fa6db
3 changed files with 150 additions and 41 deletions
|
@ -369,7 +369,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
mqtt_data.client = MQTT(hass, entry, conf)
|
||||
# Restore saved subscriptions
|
||||
if mqtt_data.subscriptions_to_restore:
|
||||
mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore
|
||||
mqtt_data.client.async_restore_tracked_subscriptions(
|
||||
mqtt_data.subscriptions_to_restore
|
||||
)
|
||||
mqtt_data.subscriptions_to_restore = []
|
||||
mqtt_data.reload_dispatchers.append(
|
||||
entry.add_update_listener(_async_config_entry_updated)
|
||||
|
@ -730,7 +732,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
await mqtt_client.async_disconnect()
|
||||
# Store remaining subscriptions to be able to restore or reload them
|
||||
# when the entry is set up again
|
||||
if mqtt_client.subscriptions:
|
||||
mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions
|
||||
if subscriptions := mqtt_client.subscriptions:
|
||||
mqtt_data.subscriptions_to_restore = subscriptions
|
||||
|
||||
return True
|
||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from functools import lru_cache, partial, wraps
|
||||
import inspect
|
||||
from itertools import groupby
|
||||
from itertools import chain, groupby
|
||||
import logging
|
||||
from operator import attrgetter
|
||||
import ssl
|
||||
|
@ -341,6 +341,11 @@ class MqttClientSetup:
|
|||
return self._client
|
||||
|
||||
|
||||
def _is_simple_match(topic: str) -> bool:
|
||||
"""Return if a topic is a simple match."""
|
||||
return not ("+" in topic or "#" in topic)
|
||||
|
||||
|
||||
class MQTT:
|
||||
"""Home Assistant MQTT client."""
|
||||
|
||||
|
@ -358,7 +363,6 @@ class MQTT:
|
|||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.conf = conf
|
||||
self.subscriptions: list[Subscription] = []
|
||||
self._simple_subscriptions: dict[str, list[Subscription]] = {}
|
||||
self._wildcard_subscriptions: list[Subscription] = []
|
||||
self.connected = False
|
||||
|
@ -390,6 +394,14 @@ class MQTT:
|
|||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
||||
)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> list[Subscription]:
|
||||
"""Return the tracked subscriptions."""
|
||||
return [
|
||||
*chain.from_iterable(self._simple_subscriptions.values()),
|
||||
*self._wildcard_subscriptions,
|
||||
]
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up listeners."""
|
||||
while self._cleanup_on_unload:
|
||||
|
@ -489,6 +501,50 @@ class MQTT:
|
|||
async with self._paho_lock:
|
||||
await self.hass.async_add_executor_job(stop)
|
||||
|
||||
@callback
|
||||
def async_restore_tracked_subscriptions(
|
||||
self, subscriptions: list[Subscription]
|
||||
) -> None:
|
||||
"""Restore tracked subscriptions after reload."""
|
||||
for subscription in subscriptions:
|
||||
self._async_track_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
@callback
|
||||
def _async_track_subscription(self, subscription: Subscription) -> None:
|
||||
"""Track a subscription.
|
||||
|
||||
This method does not send a SUBSCRIBE message to the broker.
|
||||
|
||||
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||
"""
|
||||
if _is_simple_match(subscription.topic):
|
||||
self._simple_subscriptions.setdefault(subscription.topic, []).append(
|
||||
subscription
|
||||
)
|
||||
else:
|
||||
self._wildcard_subscriptions.append(subscription)
|
||||
|
||||
@callback
|
||||
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
||||
"""Untrack a subscription.
|
||||
|
||||
This method does not send an UNSUBSCRIBE message to the broker.
|
||||
|
||||
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||
"""
|
||||
topic = subscription.topic
|
||||
try:
|
||||
if _is_simple_match(topic):
|
||||
simple_subscriptions = self._simple_subscriptions
|
||||
simple_subscriptions[topic].remove(subscription)
|
||||
if not simple_subscriptions[topic]:
|
||||
del simple_subscriptions[topic]
|
||||
else:
|
||||
self._wildcard_subscriptions.remove(subscription)
|
||||
except (KeyError, ValueError) as ex:
|
||||
raise HomeAssistantError("Can't remove subscription twice") from ex
|
||||
|
||||
async def async_subscribe(
|
||||
self,
|
||||
topic: str,
|
||||
|
@ -506,11 +562,7 @@ class MQTT:
|
|||
subscription = Subscription(
|
||||
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
|
||||
)
|
||||
self.subscriptions.append(subscription)
|
||||
if _is_simple := "+" not in topic and "#" not in topic:
|
||||
self._simple_subscriptions.setdefault(topic, []).append(subscription)
|
||||
else:
|
||||
self._wildcard_subscriptions.append(subscription)
|
||||
self._async_track_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
# Only subscribe if currently connected.
|
||||
|
@ -521,15 +573,7 @@ class MQTT:
|
|||
@callback
|
||||
def async_remove() -> None:
|
||||
"""Remove subscription."""
|
||||
if subscription not in self.subscriptions:
|
||||
raise HomeAssistantError("Can't remove subscription twice")
|
||||
self.subscriptions.remove(subscription)
|
||||
if _is_simple:
|
||||
self._simple_subscriptions[topic].remove(subscription)
|
||||
if not self._simple_subscriptions[topic]:
|
||||
del self._simple_subscriptions[topic]
|
||||
else:
|
||||
self._wildcard_subscriptions.remove(subscription)
|
||||
self._async_untrack_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
# Only unsubscribe if currently connected
|
||||
|
@ -636,18 +680,7 @@ class MQTT:
|
|||
result_code,
|
||||
)
|
||||
|
||||
# Group subscriptions to only re-subscribe once for each topic.
|
||||
keyfunc = attrgetter("topic")
|
||||
self.hass.add_job(
|
||||
self._async_perform_subscriptions,
|
||||
[
|
||||
# Re-subscribe with the highest requested qos
|
||||
(topic, max(subscription.qos for subscription in subs))
|
||||
for topic, subs in groupby(
|
||||
sorted(self.subscriptions, key=keyfunc), keyfunc
|
||||
)
|
||||
],
|
||||
)
|
||||
self.hass.create_task(self._async_resubscribe())
|
||||
|
||||
if (
|
||||
CONF_BIRTH_MESSAGE in self.conf
|
||||
|
@ -669,6 +702,20 @@ class MQTT:
|
|||
publish_birth_message(birth_message), self.hass.loop
|
||||
)
|
||||
|
||||
async def _async_resubscribe(self) -> None:
|
||||
"""Resubscribe on reconnect."""
|
||||
# Group subscriptions to only re-subscribe once for each topic.
|
||||
keyfunc = attrgetter("topic")
|
||||
await self._async_perform_subscriptions(
|
||||
[
|
||||
# Re-subscribe with the highest requested qos
|
||||
(topic, max(subscription.qos for subscription in subs))
|
||||
for topic, subs in groupby(
|
||||
sorted(self.subscriptions, key=keyfunc), keyfunc
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
|
|
|
@ -1446,6 +1446,69 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
|||
assert mqtt_client_mock.subscribe.mock_calls == expected
|
||||
|
||||
|
||||
async def test_reload_entry_with_restored_subscriptions(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mqtt_client_mock: MqttMockPahoClient,
|
||||
record_calls: MessageCallbackType,
|
||||
calls: list[ReceiveMessage],
|
||||
) -> None:
|
||||
"""Test reloading the config entry with with subscriptions restored."""
|
||||
|
||||
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||
entry.add_to_hass(hass)
|
||||
mqtt_client_mock.connect.return_value = 0
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
await mqtt.async_subscribe(hass, "wild/+/card", record_calls)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[0].topic == "test-topic"
|
||||
assert calls[0].payload == "test-payload"
|
||||
assert calls[1].topic == "wild/any/card"
|
||||
assert calls[1].payload == "wild-card-payload"
|
||||
calls.clear()
|
||||
|
||||
# Reload the entry
|
||||
config_yaml_new = {}
|
||||
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload2")
|
||||
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload2")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[0].topic == "test-topic"
|
||||
assert calls[0].payload == "test-payload2"
|
||||
assert calls[1].topic == "wild/any/card"
|
||||
assert calls[1].payload == "wild-card-payload2"
|
||||
calls.clear()
|
||||
|
||||
# Reload the entry again
|
||||
config_yaml_new = {}
|
||||
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload3")
|
||||
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload3")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[0].topic == "test-topic"
|
||||
assert calls[0].payload == "test-payload3"
|
||||
assert calls[1].topic == "wild/any/card"
|
||||
assert calls[1].payload == "wild-card-payload3"
|
||||
|
||||
|
||||
async def test_initial_setup_logs_error(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
|
@ -2051,19 +2114,16 @@ async def test_mqtt_subscribes_topics_on_connect(
|
|||
await mqtt.async_subscribe(hass, "still/pending", record_calls)
|
||||
await mqtt.async_subscribe(hass, "still/pending", record_calls, 1)
|
||||
|
||||
with patch.object(hass, "add_job") as hass_jobs:
|
||||
mqtt_client_mock.on_connect(None, None, 0, 0)
|
||||
mqtt_client_mock.on_connect(None, None, 0, 0)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mqtt_client_mock.disconnect.call_count == 0
|
||||
assert mqtt_client_mock.disconnect.call_count == 0
|
||||
|
||||
assert len(hass_jobs.mock_calls) == 1
|
||||
assert set(hass_jobs.mock_calls[0][1][1]) == {
|
||||
("home/sensor", 2),
|
||||
("still/pending", 1),
|
||||
("topic/test", 0),
|
||||
}
|
||||
assert mqtt_client_mock.subscribe.call_count == 3
|
||||
mqtt_client_mock.subscribe.assert_any_call("topic/test", 0)
|
||||
mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2)
|
||||
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
|
||||
|
||||
|
||||
async def test_setup_entry_with_config_override(
|
||||
|
|
Loading…
Add table
Reference in a new issue