Fix restore of MQTT subscriptions from reload (#88220)
This commit is contained in:
parent
d03655cb6f
commit
d2277fa6db
3 changed files with 150 additions and 41 deletions
|
@ -369,7 +369,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
mqtt_data.client = MQTT(hass, entry, conf)
|
mqtt_data.client = MQTT(hass, entry, conf)
|
||||||
# Restore saved subscriptions
|
# Restore saved subscriptions
|
||||||
if mqtt_data.subscriptions_to_restore:
|
if mqtt_data.subscriptions_to_restore:
|
||||||
mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore
|
mqtt_data.client.async_restore_tracked_subscriptions(
|
||||||
|
mqtt_data.subscriptions_to_restore
|
||||||
|
)
|
||||||
mqtt_data.subscriptions_to_restore = []
|
mqtt_data.subscriptions_to_restore = []
|
||||||
mqtt_data.reload_dispatchers.append(
|
mqtt_data.reload_dispatchers.append(
|
||||||
entry.add_update_listener(_async_config_entry_updated)
|
entry.add_update_listener(_async_config_entry_updated)
|
||||||
|
@ -730,7 +732,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
await mqtt_client.async_disconnect()
|
await mqtt_client.async_disconnect()
|
||||||
# Store remaining subscriptions to be able to restore or reload them
|
# Store remaining subscriptions to be able to restore or reload them
|
||||||
# when the entry is set up again
|
# when the entry is set up again
|
||||||
if mqtt_client.subscriptions:
|
if subscriptions := mqtt_client.subscriptions:
|
||||||
mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions
|
mqtt_data.subscriptions_to_restore = subscriptions
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
||||||
from collections.abc import Callable, Coroutine, Iterable
|
from collections.abc import Callable, Coroutine, Iterable
|
||||||
from functools import lru_cache, partial, wraps
|
from functools import lru_cache, partial, wraps
|
||||||
import inspect
|
import inspect
|
||||||
from itertools import groupby
|
from itertools import chain, groupby
|
||||||
import logging
|
import logging
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import ssl
|
import ssl
|
||||||
|
@ -341,6 +341,11 @@ class MqttClientSetup:
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
|
||||||
|
def _is_simple_match(topic: str) -> bool:
|
||||||
|
"""Return if a topic is a simple match."""
|
||||||
|
return not ("+" in topic or "#" in topic)
|
||||||
|
|
||||||
|
|
||||||
class MQTT:
|
class MQTT:
|
||||||
"""Home Assistant MQTT client."""
|
"""Home Assistant MQTT client."""
|
||||||
|
|
||||||
|
@ -358,7 +363,6 @@ class MQTT:
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
self.conf = conf
|
self.conf = conf
|
||||||
self.subscriptions: list[Subscription] = []
|
|
||||||
self._simple_subscriptions: dict[str, list[Subscription]] = {}
|
self._simple_subscriptions: dict[str, list[Subscription]] = {}
|
||||||
self._wildcard_subscriptions: list[Subscription] = []
|
self._wildcard_subscriptions: list[Subscription] = []
|
||||||
self.connected = False
|
self.connected = False
|
||||||
|
@ -390,6 +394,14 @@ class MQTT:
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subscriptions(self) -> list[Subscription]:
|
||||||
|
"""Return the tracked subscriptions."""
|
||||||
|
return [
|
||||||
|
*chain.from_iterable(self._simple_subscriptions.values()),
|
||||||
|
*self._wildcard_subscriptions,
|
||||||
|
]
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
"""Clean up listeners."""
|
"""Clean up listeners."""
|
||||||
while self._cleanup_on_unload:
|
while self._cleanup_on_unload:
|
||||||
|
@ -489,6 +501,50 @@ class MQTT:
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
await self.hass.async_add_executor_job(stop)
|
await self.hass.async_add_executor_job(stop)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_restore_tracked_subscriptions(
|
||||||
|
self, subscriptions: list[Subscription]
|
||||||
|
) -> None:
|
||||||
|
"""Restore tracked subscriptions after reload."""
|
||||||
|
for subscription in subscriptions:
|
||||||
|
self._async_track_subscription(subscription)
|
||||||
|
self._matching_subscriptions.cache_clear()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_track_subscription(self, subscription: Subscription) -> None:
|
||||||
|
"""Track a subscription.
|
||||||
|
|
||||||
|
This method does not send a SUBSCRIBE message to the broker.
|
||||||
|
|
||||||
|
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||||
|
"""
|
||||||
|
if _is_simple_match(subscription.topic):
|
||||||
|
self._simple_subscriptions.setdefault(subscription.topic, []).append(
|
||||||
|
subscription
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._wildcard_subscriptions.append(subscription)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
||||||
|
"""Untrack a subscription.
|
||||||
|
|
||||||
|
This method does not send an UNSUBSCRIBE message to the broker.
|
||||||
|
|
||||||
|
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||||
|
"""
|
||||||
|
topic = subscription.topic
|
||||||
|
try:
|
||||||
|
if _is_simple_match(topic):
|
||||||
|
simple_subscriptions = self._simple_subscriptions
|
||||||
|
simple_subscriptions[topic].remove(subscription)
|
||||||
|
if not simple_subscriptions[topic]:
|
||||||
|
del simple_subscriptions[topic]
|
||||||
|
else:
|
||||||
|
self._wildcard_subscriptions.remove(subscription)
|
||||||
|
except (KeyError, ValueError) as ex:
|
||||||
|
raise HomeAssistantError("Can't remove subscription twice") from ex
|
||||||
|
|
||||||
async def async_subscribe(
|
async def async_subscribe(
|
||||||
self,
|
self,
|
||||||
topic: str,
|
topic: str,
|
||||||
|
@ -506,11 +562,7 @@ class MQTT:
|
||||||
subscription = Subscription(
|
subscription = Subscription(
|
||||||
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
|
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
|
||||||
)
|
)
|
||||||
self.subscriptions.append(subscription)
|
self._async_track_subscription(subscription)
|
||||||
if _is_simple := "+" not in topic and "#" not in topic:
|
|
||||||
self._simple_subscriptions.setdefault(topic, []).append(subscription)
|
|
||||||
else:
|
|
||||||
self._wildcard_subscriptions.append(subscription)
|
|
||||||
self._matching_subscriptions.cache_clear()
|
self._matching_subscriptions.cache_clear()
|
||||||
|
|
||||||
# Only subscribe if currently connected.
|
# Only subscribe if currently connected.
|
||||||
|
@ -521,15 +573,7 @@ class MQTT:
|
||||||
@callback
|
@callback
|
||||||
def async_remove() -> None:
|
def async_remove() -> None:
|
||||||
"""Remove subscription."""
|
"""Remove subscription."""
|
||||||
if subscription not in self.subscriptions:
|
self._async_untrack_subscription(subscription)
|
||||||
raise HomeAssistantError("Can't remove subscription twice")
|
|
||||||
self.subscriptions.remove(subscription)
|
|
||||||
if _is_simple:
|
|
||||||
self._simple_subscriptions[topic].remove(subscription)
|
|
||||||
if not self._simple_subscriptions[topic]:
|
|
||||||
del self._simple_subscriptions[topic]
|
|
||||||
else:
|
|
||||||
self._wildcard_subscriptions.remove(subscription)
|
|
||||||
self._matching_subscriptions.cache_clear()
|
self._matching_subscriptions.cache_clear()
|
||||||
|
|
||||||
# Only unsubscribe if currently connected
|
# Only unsubscribe if currently connected
|
||||||
|
@ -636,18 +680,7 @@ class MQTT:
|
||||||
result_code,
|
result_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Group subscriptions to only re-subscribe once for each topic.
|
self.hass.create_task(self._async_resubscribe())
|
||||||
keyfunc = attrgetter("topic")
|
|
||||||
self.hass.add_job(
|
|
||||||
self._async_perform_subscriptions,
|
|
||||||
[
|
|
||||||
# Re-subscribe with the highest requested qos
|
|
||||||
(topic, max(subscription.qos for subscription in subs))
|
|
||||||
for topic, subs in groupby(
|
|
||||||
sorted(self.subscriptions, key=keyfunc), keyfunc
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
CONF_BIRTH_MESSAGE in self.conf
|
CONF_BIRTH_MESSAGE in self.conf
|
||||||
|
@ -669,6 +702,20 @@ class MQTT:
|
||||||
publish_birth_message(birth_message), self.hass.loop
|
publish_birth_message(birth_message), self.hass.loop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_resubscribe(self) -> None:
|
||||||
|
"""Resubscribe on reconnect."""
|
||||||
|
# Group subscriptions to only re-subscribe once for each topic.
|
||||||
|
keyfunc = attrgetter("topic")
|
||||||
|
await self._async_perform_subscriptions(
|
||||||
|
[
|
||||||
|
# Re-subscribe with the highest requested qos
|
||||||
|
(topic, max(subscription.qos for subscription in subs))
|
||||||
|
for topic, subs in groupby(
|
||||||
|
sorted(self.subscriptions, key=keyfunc), keyfunc
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _mqtt_on_message(
|
def _mqtt_on_message(
|
||||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -1446,6 +1446,69 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||||
assert mqtt_client_mock.subscribe.mock_calls == expected
|
assert mqtt_client_mock.subscribe.mock_calls == expected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_entry_with_restored_subscriptions(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
record_calls: MessageCallbackType,
|
||||||
|
calls: list[ReceiveMessage],
|
||||||
|
) -> None:
|
||||||
|
"""Test reloading the config entry with with subscriptions restored."""
|
||||||
|
|
||||||
|
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
mqtt_client_mock.connect.return_value = 0
|
||||||
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
await mqtt.async_subscribe(hass, "wild/+/card", record_calls)
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||||
|
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload")
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0].topic == "test-topic"
|
||||||
|
assert calls[0].payload == "test-payload"
|
||||||
|
assert calls[1].topic == "wild/any/card"
|
||||||
|
assert calls[1].payload == "wild-card-payload"
|
||||||
|
calls.clear()
|
||||||
|
|
||||||
|
# Reload the entry
|
||||||
|
config_yaml_new = {}
|
||||||
|
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, "test-topic", "test-payload2")
|
||||||
|
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload2")
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0].topic == "test-topic"
|
||||||
|
assert calls[0].payload == "test-payload2"
|
||||||
|
assert calls[1].topic == "wild/any/card"
|
||||||
|
assert calls[1].payload == "wild-card-payload2"
|
||||||
|
calls.clear()
|
||||||
|
|
||||||
|
# Reload the entry again
|
||||||
|
config_yaml_new = {}
|
||||||
|
await help_test_entry_reload_with_new_config(hass, tmp_path, config_yaml_new)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, "test-topic", "test-payload3")
|
||||||
|
async_fire_mqtt_message(hass, "wild/any/card", "wild-card-payload3")
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0].topic == "test-topic"
|
||||||
|
assert calls[0].payload == "test-payload3"
|
||||||
|
assert calls[1].topic == "wild/any/card"
|
||||||
|
assert calls[1].payload == "wild-card-payload3"
|
||||||
|
|
||||||
|
|
||||||
async def test_initial_setup_logs_error(
|
async def test_initial_setup_logs_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
@ -2051,19 +2114,16 @@ async def test_mqtt_subscribes_topics_on_connect(
|
||||||
await mqtt.async_subscribe(hass, "still/pending", record_calls)
|
await mqtt.async_subscribe(hass, "still/pending", record_calls)
|
||||||
await mqtt.async_subscribe(hass, "still/pending", record_calls, 1)
|
await mqtt.async_subscribe(hass, "still/pending", record_calls, 1)
|
||||||
|
|
||||||
with patch.object(hass, "add_job") as hass_jobs:
|
mqtt_client_mock.on_connect(None, None, 0, 0)
|
||||||
mqtt_client_mock.on_connect(None, None, 0, 0)
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert mqtt_client_mock.disconnect.call_count == 0
|
assert mqtt_client_mock.disconnect.call_count == 0
|
||||||
|
|
||||||
assert len(hass_jobs.mock_calls) == 1
|
assert mqtt_client_mock.subscribe.call_count == 3
|
||||||
assert set(hass_jobs.mock_calls[0][1][1]) == {
|
mqtt_client_mock.subscribe.assert_any_call("topic/test", 0)
|
||||||
("home/sensor", 2),
|
mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2)
|
||||||
("still/pending", 1),
|
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
|
||||||
("topic/test", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_entry_with_config_override(
|
async def test_setup_entry_with_config_override(
|
||||||
|
|
Loading…
Add table
Reference in a new issue