diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index ce538a6af13..3ae880d2b83 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -90,6 +90,7 @@ _LOGGER = logging.getLogger(__name__) DISCOVERY_COOLDOWN = 2 INITIAL_SUBSCRIBE_COOLDOWN = 1.0 SUBSCRIBE_COOLDOWN = 0.1 +UNSUBSCRIBE_COOLDOWN = 0.1 TIMEOUT_ACK = 10 SubscribePayloadType = str | bytes # Only bytes if encoding is None @@ -387,6 +388,10 @@ class MQTT: ) self._max_qos: dict[str, int] = {} # topic, max qos self._pending_subscriptions: dict[str, int] = {} # topic, qos + self._unsubscribe_debouncer = EnsureJobAfterCooldown( + UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes + ) + self._pending_unsubscribes: set[str] = set() # topic if self.hass.state == CoreState.running: self._ha_started.set() @@ -510,6 +515,10 @@ class MQTT: await self._subscribe_debouncer.async_cleanup() # reset timeout to initial subscribe cooldown self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN) + # stop the unsubscribe debouncer + await self._unsubscribe_debouncer.async_cleanup() + # make sure the unsubscribes are processed + await self._async_perform_unsubscribes() # wait for ACKs to be processed async with self._pending_operations_condition: @@ -573,6 +582,9 @@ class MQTT: max_qos = max(qos, self._max_qos.setdefault(topic, qos)) self._max_qos[topic] = max_qos self._pending_subscriptions[topic] = max_qos + # Cancel any pending unsubscribe since we are subscribing now + if topic in self._pending_unsubscribes: + self._pending_unsubscribes.remove(topic) if queue_only: return self._subscribe_debouncer.async_schedule() @@ -608,22 +620,13 @@ class MQTT: self._matching_subscriptions.cache_clear() # Only unsubscribe if currently connected if self.connected: - self.hass.async_create_task(self._async_unsubscribe(topic)) + self._async_unsubscribe(topic) return async_remove - async def _async_unsubscribe(self, topic: str) -> None: - """Unsubscribe from a topic. - - This method is a coroutine. - """ - - def _client_unsubscribe(topic: str) -> int: - result, mid = self._mqttc.unsubscribe(topic) - _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) - _raise_on_error(result) - return mid - + @callback + def _async_unsubscribe(self, topic: str) -> None: + """Unsubscribe from a topic.""" if self._is_active_subscription(topic): if self._max_qos[topic] == 0: return @@ -636,11 +639,9 @@ class MQTT: if topic in self._pending_subscriptions: # avoid any pending subscription to be executed del self._pending_subscriptions[topic] - async with self._paho_lock: - mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic) - await self._register_mid(mid) - self.hass.async_create_task(self._wait_for_mid(mid)) + self._pending_unsubscribes.add(topic) + self._unsubscribe_debouncer.async_schedule() async def _async_perform_subscriptions(self) -> None: """Perform MQTT client subscriptions.""" @@ -677,6 +678,24 @@ class MQTT: else: _raise_on_error(result) + async def _async_perform_unsubscribes(self) -> None: + """Perform pending MQTT client unsubscribes.""" + if not self._pending_unsubscribes: + return + + topics = list(self._pending_unsubscribes) + self._pending_unsubscribes = set() + + async with self._paho_lock: + result, mid = await self.hass.async_add_executor_job( + self._mqttc.unsubscribe, topics + ) + _raise_on_error(result) + for topic in topics: + _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) + + await self._wait_for_mid(mid) + def _mqtt_on_connect( self, _mqttc: mqtt.Client, diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 800809f15ad..8d3c43744fc 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -1,4 +1,5 @@ """The tests for the MQTT discovery.""" +import asyncio import copy import json from pathlib import Path @@ -1376,6 +1377,7 @@ async def test_complex_discovery_topic_prefix( @patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) async def test_mqtt_integration_discovery_subscribe_unsubscribe( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1407,15 +1409,18 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( return self.async_abort(reason="already_configured") with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + await asyncio.sleep(0.1) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert not mqtt_client_mock.unsubscribe.called async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await asyncio.sleep(0.1) await hass.async_block_till_done() - mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") + mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"]) mqtt_client_mock.unsubscribe.reset_mock() async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await asyncio.sleep(0.1) await hass.async_block_till_done() assert not mqtt_client_mock.unsubscribe.called @@ -1423,6 +1428,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( @patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) async def test_mqtt_discovery_unsubscribe_once( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1456,9 +1462,10 @@ async def test_mqtt_discovery_unsubscribe_once( with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") + await asyncio.sleep(0.1) await hass.async_block_till_done() await hass.async_block_till_done() - mqtt_client_mock.unsubscribe.assert_called_once_with("comp/discovery/#") + mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.SENSOR]) diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index dec24128cc6..5494b24c398 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -906,6 +906,45 @@ async def test_subscribe_topic( unsub() +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2) +async def test_subscribe_and_resubscribe( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + mqtt_client_mock: MqttMockPahoClient, + calls: list[ReceiveMessage], + record_calls: MessageCallbackType, +) -> None: + """Test resubscribing within the debounce time.""" + mqtt_mock = await mqtt_mock_entry() + # Fake that the client is connected + mqtt_mock().connected = True + + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + # This unsub will be un-done with the following subscribe + # unsubscribe should not be called at the broker + unsub() + await asyncio.sleep(0.1) + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + await asyncio.sleep(0.1) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].topic == "test-topic" + assert calls[0].payload == "test-payload" + # assert unsubscribe was not called + mqtt_client_mock.unsubscribe.assert_not_called() + + unsub() + + await asyncio.sleep(0.2) + await hass.async_block_till_done() + mqtt_client_mock.unsubscribe.assert_called_once_with(["test-topic"]) + + async def test_subscribe_topic_non_async( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator,