Debounce and group mqtt unsubscribes (#92201)

* Debounce MQTT unsubscribes and merge to one call

* Make _async_unsubscribe a callback

* Make sure unsubscribes are processed

* Move debug log out of lock

* Reduce calls and raise outside lock

* Cancel any unsubscribe when queing

* Copy pending unsubscribes

* Only convert topics to list once

* No copy needed

* Typo in comment
This commit is contained in:
Jan Bouwhuis 2023-05-09 16:36:19 +02:00 committed by GitHub
parent 0199c6f5b2
commit 25549eed85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 19 deletions

View file

@ -90,6 +90,7 @@ _LOGGER = logging.getLogger(__name__)
DISCOVERY_COOLDOWN = 2 DISCOVERY_COOLDOWN = 2
INITIAL_SUBSCRIBE_COOLDOWN = 1.0 INITIAL_SUBSCRIBE_COOLDOWN = 1.0
SUBSCRIBE_COOLDOWN = 0.1 SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10 TIMEOUT_ACK = 10
SubscribePayloadType = str | bytes # Only bytes if encoding is None SubscribePayloadType = str | bytes # Only bytes if encoding is None
@ -387,6 +388,10 @@ class MQTT:
) )
self._max_qos: dict[str, int] = {} # topic, max qos self._max_qos: dict[str, int] = {} # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
)
self._pending_unsubscribes: set[str] = set() # topic
if self.hass.state == CoreState.running: if self.hass.state == CoreState.running:
self._ha_started.set() self._ha_started.set()
@ -510,6 +515,10 @@ class MQTT:
await self._subscribe_debouncer.async_cleanup() await self._subscribe_debouncer.async_cleanup()
# reset timeout to initial subscribe cooldown # reset timeout to initial subscribe cooldown
self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN) self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN)
# stop the unsubscribe debouncer
await self._unsubscribe_debouncer.async_cleanup()
# make sure the unsubscribes are processed
await self._async_perform_unsubscribes()
# wait for ACKs to be processed # wait for ACKs to be processed
async with self._pending_operations_condition: async with self._pending_operations_condition:
@ -573,6 +582,9 @@ class MQTT:
max_qos = max(qos, self._max_qos.setdefault(topic, qos)) max_qos = max(qos, self._max_qos.setdefault(topic, qos))
self._max_qos[topic] = max_qos self._max_qos[topic] = max_qos
self._pending_subscriptions[topic] = max_qos self._pending_subscriptions[topic] = max_qos
# Cancel any pending unsubscribe since we are subscribing now
if topic in self._pending_unsubscribes:
self._pending_unsubscribes.remove(topic)
if queue_only: if queue_only:
return return
self._subscribe_debouncer.async_schedule() self._subscribe_debouncer.async_schedule()
@ -608,22 +620,13 @@ class MQTT:
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
# Only unsubscribe if currently connected # Only unsubscribe if currently connected
if self.connected: if self.connected:
self.hass.async_create_task(self._async_unsubscribe(topic)) self._async_unsubscribe(topic)
return async_remove return async_remove
async def _async_unsubscribe(self, topic: str) -> None: @callback
"""Unsubscribe from a topic. def _async_unsubscribe(self, topic: str) -> None:
"""Unsubscribe from a topic."""
This method is a coroutine.
"""
def _client_unsubscribe(topic: str) -> int:
result, mid = self._mqttc.unsubscribe(topic)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result)
return mid
if self._is_active_subscription(topic): if self._is_active_subscription(topic):
if self._max_qos[topic] == 0: if self._max_qos[topic] == 0:
return return
@ -636,11 +639,9 @@ class MQTT:
if topic in self._pending_subscriptions: if topic in self._pending_subscriptions:
# avoid any pending subscription to be executed # avoid any pending subscription to be executed
del self._pending_subscriptions[topic] del self._pending_subscriptions[topic]
async with self._paho_lock:
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
await self._register_mid(mid)
self.hass.async_create_task(self._wait_for_mid(mid)) self._pending_unsubscribes.add(topic)
self._unsubscribe_debouncer.async_schedule()
async def _async_perform_subscriptions(self) -> None: async def _async_perform_subscriptions(self) -> None:
"""Perform MQTT client subscriptions.""" """Perform MQTT client subscriptions."""
@ -677,6 +678,24 @@ class MQTT:
else: else:
_raise_on_error(result) _raise_on_error(result)
async def _async_perform_unsubscribes(self) -> None:
"""Perform pending MQTT client unsubscribes."""
if not self._pending_unsubscribes:
return
topics = list(self._pending_unsubscribes)
self._pending_unsubscribes = set()
async with self._paho_lock:
result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topics
)
_raise_on_error(result)
for topic in topics:
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
await self._wait_for_mid(mid)
def _mqtt_on_connect( def _mqtt_on_connect(
self, self,
_mqttc: mqtt.Client, _mqttc: mqtt.Client,

View file

@ -1,4 +1,5 @@
"""The tests for the MQTT discovery.""" """The tests for the MQTT discovery."""
import asyncio
import copy import copy
import json import json
from pathlib import Path from pathlib import Path
@ -1376,6 +1377,7 @@ async def test_complex_discovery_topic_prefix(
@patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.PLATFORMS", [])
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_integration_discovery_subscribe_unsubscribe( async def test_mqtt_integration_discovery_subscribe_unsubscribe(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1407,15 +1409,18 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
await asyncio.sleep(0.1)
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
await asyncio.sleep(0.1)
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"])
mqtt_client_mock.unsubscribe.reset_mock() mqtt_client_mock.unsubscribe.reset_mock()
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
await asyncio.sleep(0.1)
await hass.async_block_till_done() await hass.async_block_till_done()
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
@ -1423,6 +1428,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
@patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.PLATFORMS", [])
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_discovery_unsubscribe_once( async def test_mqtt_discovery_unsubscribe_once(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
@ -1456,9 +1462,10 @@ async def test_mqtt_discovery_unsubscribe_once(
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
await asyncio.sleep(0.1)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"])
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.SENSOR]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.SENSOR])

View file

@ -906,6 +906,45 @@ async def test_subscribe_topic(
unsub() unsub()
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
async def test_subscribe_and_resubscribe(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test resubscribing within the debounce time."""
mqtt_mock = await mqtt_mock_entry()
# Fake that the client is connected
mqtt_mock().connected = True
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
# This unsub will be un-done with the following subscribe
# unsubscribe should not be called at the broker
unsub()
await asyncio.sleep(0.1)
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
await asyncio.sleep(0.1)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].topic == "test-topic"
assert calls[0].payload == "test-payload"
# assert unsubscribe was not called
mqtt_client_mock.unsubscribe.assert_not_called()
unsub()
await asyncio.sleep(0.2)
await hass.async_block_till_done()
mqtt_client_mock.unsubscribe.assert_called_once_with(["test-topic"])
async def test_subscribe_topic_non_async( async def test_subscribe_topic_non_async(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,