Debounce and group mqtt unsubscribes (#92201)
* Debounce MQTT unsubscribes and merge to one call * Make _async_unsubscribe a callback * Make sure unsubscribes are processed * Move debug log out of lock * Reduce calls and raise outside lock * Cancel any unsubscribe when queing * Copy pending unsubscribes * Only convert topics to list once * No copy needed * Typo in comment
This commit is contained in:
parent
0199c6f5b2
commit
25549eed85
3 changed files with 84 additions and 19 deletions
|
@ -90,6 +90,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
DISCOVERY_COOLDOWN = 2
|
DISCOVERY_COOLDOWN = 2
|
||||||
INITIAL_SUBSCRIBE_COOLDOWN = 1.0
|
INITIAL_SUBSCRIBE_COOLDOWN = 1.0
|
||||||
SUBSCRIBE_COOLDOWN = 0.1
|
SUBSCRIBE_COOLDOWN = 0.1
|
||||||
|
UNSUBSCRIBE_COOLDOWN = 0.1
|
||||||
TIMEOUT_ACK = 10
|
TIMEOUT_ACK = 10
|
||||||
|
|
||||||
SubscribePayloadType = str | bytes # Only bytes if encoding is None
|
SubscribePayloadType = str | bytes # Only bytes if encoding is None
|
||||||
|
@ -387,6 +388,10 @@ class MQTT:
|
||||||
)
|
)
|
||||||
self._max_qos: dict[str, int] = {} # topic, max qos
|
self._max_qos: dict[str, int] = {} # topic, max qos
|
||||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||||
|
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||||
|
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
|
||||||
|
)
|
||||||
|
self._pending_unsubscribes: set[str] = set() # topic
|
||||||
|
|
||||||
if self.hass.state == CoreState.running:
|
if self.hass.state == CoreState.running:
|
||||||
self._ha_started.set()
|
self._ha_started.set()
|
||||||
|
@ -510,6 +515,10 @@ class MQTT:
|
||||||
await self._subscribe_debouncer.async_cleanup()
|
await self._subscribe_debouncer.async_cleanup()
|
||||||
# reset timeout to initial subscribe cooldown
|
# reset timeout to initial subscribe cooldown
|
||||||
self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN)
|
self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN)
|
||||||
|
# stop the unsubscribe debouncer
|
||||||
|
await self._unsubscribe_debouncer.async_cleanup()
|
||||||
|
# make sure the unsubscribes are processed
|
||||||
|
await self._async_perform_unsubscribes()
|
||||||
|
|
||||||
# wait for ACKs to be processed
|
# wait for ACKs to be processed
|
||||||
async with self._pending_operations_condition:
|
async with self._pending_operations_condition:
|
||||||
|
@ -573,6 +582,9 @@ class MQTT:
|
||||||
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
|
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
|
||||||
self._max_qos[topic] = max_qos
|
self._max_qos[topic] = max_qos
|
||||||
self._pending_subscriptions[topic] = max_qos
|
self._pending_subscriptions[topic] = max_qos
|
||||||
|
# Cancel any pending unsubscribe since we are subscribing now
|
||||||
|
if topic in self._pending_unsubscribes:
|
||||||
|
self._pending_unsubscribes.remove(topic)
|
||||||
if queue_only:
|
if queue_only:
|
||||||
return
|
return
|
||||||
self._subscribe_debouncer.async_schedule()
|
self._subscribe_debouncer.async_schedule()
|
||||||
|
@ -608,22 +620,13 @@ class MQTT:
|
||||||
self._matching_subscriptions.cache_clear()
|
self._matching_subscriptions.cache_clear()
|
||||||
# Only unsubscribe if currently connected
|
# Only unsubscribe if currently connected
|
||||||
if self.connected:
|
if self.connected:
|
||||||
self.hass.async_create_task(self._async_unsubscribe(topic))
|
self._async_unsubscribe(topic)
|
||||||
|
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
async def _async_unsubscribe(self, topic: str) -> None:
|
@callback
|
||||||
"""Unsubscribe from a topic.
|
def _async_unsubscribe(self, topic: str) -> None:
|
||||||
|
"""Unsubscribe from a topic."""
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _client_unsubscribe(topic: str) -> int:
|
|
||||||
result, mid = self._mqttc.unsubscribe(topic)
|
|
||||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
|
||||||
_raise_on_error(result)
|
|
||||||
return mid
|
|
||||||
|
|
||||||
if self._is_active_subscription(topic):
|
if self._is_active_subscription(topic):
|
||||||
if self._max_qos[topic] == 0:
|
if self._max_qos[topic] == 0:
|
||||||
return
|
return
|
||||||
|
@ -636,11 +639,9 @@ class MQTT:
|
||||||
if topic in self._pending_subscriptions:
|
if topic in self._pending_subscriptions:
|
||||||
# avoid any pending subscription to be executed
|
# avoid any pending subscription to be executed
|
||||||
del self._pending_subscriptions[topic]
|
del self._pending_subscriptions[topic]
|
||||||
async with self._paho_lock:
|
|
||||||
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
|
|
||||||
await self._register_mid(mid)
|
|
||||||
|
|
||||||
self.hass.async_create_task(self._wait_for_mid(mid))
|
self._pending_unsubscribes.add(topic)
|
||||||
|
self._unsubscribe_debouncer.async_schedule()
|
||||||
|
|
||||||
async def _async_perform_subscriptions(self) -> None:
|
async def _async_perform_subscriptions(self) -> None:
|
||||||
"""Perform MQTT client subscriptions."""
|
"""Perform MQTT client subscriptions."""
|
||||||
|
@ -677,6 +678,24 @@ class MQTT:
|
||||||
else:
|
else:
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
|
|
||||||
|
async def _async_perform_unsubscribes(self) -> None:
|
||||||
|
"""Perform pending MQTT client unsubscribes."""
|
||||||
|
if not self._pending_unsubscribes:
|
||||||
|
return
|
||||||
|
|
||||||
|
topics = list(self._pending_unsubscribes)
|
||||||
|
self._pending_unsubscribes = set()
|
||||||
|
|
||||||
|
async with self._paho_lock:
|
||||||
|
result, mid = await self.hass.async_add_executor_job(
|
||||||
|
self._mqttc.unsubscribe, topics
|
||||||
|
)
|
||||||
|
_raise_on_error(result)
|
||||||
|
for topic in topics:
|
||||||
|
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||||
|
|
||||||
|
await self._wait_for_mid(mid)
|
||||||
|
|
||||||
def _mqtt_on_connect(
|
def _mqtt_on_connect(
|
||||||
self,
|
self,
|
||||||
_mqttc: mqtt.Client,
|
_mqttc: mqtt.Client,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""The tests for the MQTT discovery."""
|
"""The tests for the MQTT discovery."""
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -1376,6 +1377,7 @@ async def test_complex_discovery_topic_prefix(
|
||||||
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
|
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
|
||||||
async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mqtt_client_mock: MqttMockPahoClient,
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
@ -1407,15 +1409,18 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||||
return self.async_abort(reason="already_configured")
|
return self.async_abort(reason="already_configured")
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
assert not mqtt_client_mock.unsubscribe.called
|
assert not mqtt_client_mock.unsubscribe.called
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#")
|
mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"])
|
||||||
mqtt_client_mock.unsubscribe.reset_mock()
|
mqtt_client_mock.unsubscribe.reset_mock()
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert not mqtt_client_mock.unsubscribe.called
|
assert not mqtt_client_mock.unsubscribe.called
|
||||||
|
|
||||||
|
@ -1423,6 +1428,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||||
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
|
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
|
||||||
async def test_mqtt_discovery_unsubscribe_once(
|
async def test_mqtt_discovery_unsubscribe_once(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mqtt_client_mock: MqttMockPahoClient,
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
@ -1456,9 +1462,10 @@ async def test_mqtt_discovery_unsubscribe_once(
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#")
|
mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"])
|
||||||
|
|
||||||
|
|
||||||
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.SENSOR])
|
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.SENSOR])
|
||||||
|
|
|
@ -906,6 +906,45 @@ async def test_subscribe_topic(
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
|
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
||||||
|
async def test_subscribe_and_resubscribe(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
calls: list[ReceiveMessage],
|
||||||
|
record_calls: MessageCallbackType,
|
||||||
|
) -> None:
|
||||||
|
"""Test resubscribing within the debounce time."""
|
||||||
|
mqtt_mock = await mqtt_mock_entry()
|
||||||
|
# Fake that the client is connected
|
||||||
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
|
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
# This unsub will be un-done with the following subscribe
|
||||||
|
# unsubscribe should not be called at the broker
|
||||||
|
unsub()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].topic == "test-topic"
|
||||||
|
assert calls[0].payload == "test-payload"
|
||||||
|
# assert unsubscribe was not called
|
||||||
|
mqtt_client_mock.unsubscribe.assert_not_called()
|
||||||
|
|
||||||
|
unsub()
|
||||||
|
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
mqtt_client_mock.unsubscribe.assert_called_once_with(["test-topic"])
|
||||||
|
|
||||||
|
|
||||||
async def test_subscribe_topic_non_async(
|
async def test_subscribe_topic_non_async(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
|
Loading…
Add table
Reference in a new issue