Fix restore of MQTT subscriptions from reload (#88220)

This commit is contained in:
J. Nick Koston 2023-02-16 11:14:26 -06:00 committed by GitHub
parent d03655cb6f
commit d2277fa6db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 150 additions and 41 deletions

View file

@ -5,7 +5,7 @@ import asyncio
from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
from itertools import chain, groupby
import logging
from operator import attrgetter
import ssl
@ -341,6 +341,11 @@ class MqttClientSetup:
return self._client
def _is_simple_match(topic: str) -> bool:
"""Return if a topic is a simple match."""
return not ("+" in topic or "#" in topic)
class MQTT:
"""Home Assistant MQTT client."""
@ -358,7 +363,6 @@ class MQTT:
self.hass = hass
self.config_entry = config_entry
self.conf = conf
self.subscriptions: list[Subscription] = []
self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = []
self.connected = False
@ -390,6 +394,14 @@ class MQTT:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
)
@property
def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions."""
return [
*chain.from_iterable(self._simple_subscriptions.values()),
*self._wildcard_subscriptions,
]
def cleanup(self) -> None:
"""Clean up listeners."""
while self._cleanup_on_unload:
@ -489,6 +501,50 @@ class MQTT:
async with self._paho_lock:
await self.hass.async_add_executor_job(stop)
@callback
def async_restore_tracked_subscriptions(
self, subscriptions: list[Subscription]
) -> None:
"""Restore tracked subscriptions after reload."""
for subscription in subscriptions:
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
@callback
def _async_track_subscription(self, subscription: Subscription) -> None:
"""Track a subscription.
This method does not send a SUBSCRIBE message to the broker.
The caller is responsible clearing the cache of _matching_subscriptions.
"""
if _is_simple_match(subscription.topic):
self._simple_subscriptions.setdefault(subscription.topic, []).append(
subscription
)
else:
self._wildcard_subscriptions.append(subscription)
@callback
def _async_untrack_subscription(self, subscription: Subscription) -> None:
"""Untrack a subscription.
This method does not send an UNSUBSCRIBE message to the broker.
The caller is responsible clearing the cache of _matching_subscriptions.
"""
topic = subscription.topic
try:
if _is_simple_match(topic):
simple_subscriptions = self._simple_subscriptions
simple_subscriptions[topic].remove(subscription)
if not simple_subscriptions[topic]:
del simple_subscriptions[topic]
else:
self._wildcard_subscriptions.remove(subscription)
except (KeyError, ValueError) as ex:
raise HomeAssistantError("Can't remove subscription twice") from ex
async def async_subscribe(
self,
topic: str,
@ -506,11 +562,7 @@ class MQTT:
subscription = Subscription(
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
)
self.subscriptions.append(subscription)
if _is_simple := "+" not in topic and "#" not in topic:
self._simple_subscriptions.setdefault(topic, []).append(subscription)
else:
self._wildcard_subscriptions.append(subscription)
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
# Only subscribe if currently connected.
@ -521,15 +573,7 @@ class MQTT:
@callback
def async_remove() -> None:
"""Remove subscription."""
if subscription not in self.subscriptions:
raise HomeAssistantError("Can't remove subscription twice")
self.subscriptions.remove(subscription)
if _is_simple:
self._simple_subscriptions[topic].remove(subscription)
if not self._simple_subscriptions[topic]:
del self._simple_subscriptions[topic]
else:
self._wildcard_subscriptions.remove(subscription)
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
# Only unsubscribe if currently connected
@ -636,18 +680,7 @@ class MQTT:
result_code,
)
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic")
self.hass.add_job(
self._async_perform_subscriptions,
[
# Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc
)
],
)
self.hass.create_task(self._async_resubscribe())
if (
CONF_BIRTH_MESSAGE in self.conf
@ -669,6 +702,20 @@ class MQTT:
publish_birth_message(birth_message), self.hass.loop
)
async def _async_resubscribe(self) -> None:
"""Resubscribe on reconnect."""
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic")
await self._async_perform_subscriptions(
[
# Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc
)
]
)
def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None: