From 5ef92e5e956de4aada3fbe990f371116f219e7fa Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Wed, 20 Jul 2022 11:58:54 +0200 Subject: [PATCH] Fix MQTT race awaiting an ACK when disconnecting (#75117) Co-authored-by: Erik --- homeassistant/components/mqtt/client.py | 54 +++++++++++++------------ tests/components/mqtt/test_init.py | 46 +++++++++++++++++++-- 2 files changed, 72 insertions(+), 28 deletions(-) diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 9eeed426d17..6e6f67e4e7a 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -325,11 +325,11 @@ class MQTT: self._ha_started = asyncio.Event() self._last_subscribe = time.time() self._mqttc: mqtt.Client = None - self._paho_lock = asyncio.Lock() - self._pending_acks: set[int] = set() self._cleanup_on_unload: list[Callable] = [] - self._pending_operations: dict[str, asyncio.Event] = {} + self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client + self._pending_operations: dict[int, asyncio.Event] = {} + self._pending_operations_condition = asyncio.Condition() if self.hass.state == CoreState.running: self._ha_started.set() @@ -431,13 +431,13 @@ class MQTT: # Do not disconnect, we want the broker to always publish will self._mqttc.loop_stop() - # wait for ACK-s to be processes (unsubscribe only) - async with self._paho_lock: - tasks = [ - self.hass.async_create_task(self._wait_for_mid(mid)) - for mid in self._pending_acks - ] - await asyncio.gather(*tasks) + def no_more_acks() -> bool: + """Return False if there are unprocessed ACKs.""" + return not bool(self._pending_operations) + + # wait for ACK-s to be processed (unsubscribe only) + async with self._pending_operations_condition: + await self._pending_operations_condition.wait_for(no_more_acks) # stop the MQTT loop await self.hass.async_add_executor_job(stop) @@ -487,19 +487,21 @@ class MQTT: This method is a coroutine. 
""" - def _client_unsubscribe(topic: str) -> None: + def _client_unsubscribe(topic: str) -> int: result: int | None = None result, mid = self._mqttc.unsubscribe(topic) _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _raise_on_error(result) - self._pending_acks.add(mid) + return mid if any(other.topic == topic for other in self.subscriptions): # Other subscriptions on topic remaining - don't unsubscribe. return async with self._paho_lock: - await self.hass.async_add_executor_job(_client_unsubscribe, topic) + mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic) + await self._register_mid(mid) + self.hass.async_create_task(self._wait_for_mid(mid)) async def _async_perform_subscriptions( self, subscriptions: Iterable[tuple[str, int]] @@ -647,14 +649,18 @@ class MQTT: """Publish / Subscribe / Unsubscribe callback.""" self.hass.add_job(self._mqtt_handle_mid, mid) - @callback - def _mqtt_handle_mid(self, mid) -> None: + async def _mqtt_handle_mid(self, mid: int) -> None: # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # may be executed first. - if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() + await self._register_mid(mid) self._pending_operations[mid].set() + async def _register_mid(self, mid: int) -> None: + """Create Event for an expected ACK.""" + async with self._pending_operations_condition: + if mid not in self._pending_operations: + self._pending_operations[mid] = asyncio.Event() + def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: """Disconnected callback.""" self.connected = False @@ -666,12 +672,11 @@ class MQTT: result_code, ) - async def _wait_for_mid(self, mid): + async def _wait_for_mid(self, mid: int) -> None: """Wait for ACK from broker.""" # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # may be executed first. 
- if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() + await self._register_mid(mid) try: await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK) except asyncio.TimeoutError: @@ -679,11 +684,10 @@ class MQTT: "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid ) finally: - del self._pending_operations[mid] - # Cleanup ACK sync buffer - async with self._paho_lock: - if mid in self._pending_acks: - self._pending_acks.remove(mid) + async with self._pending_operations_condition: + # Cleanup ACK sync buffer + del self._pending_operations[mid] + self._pending_operations_condition.notify_all() async def _discovery_cooldown(self): now = time.time() diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index de63528a08b..cec6038ae04 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -4,7 +4,6 @@ import copy from datetime import datetime, timedelta from functools import partial import json -import logging import ssl from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch @@ -47,8 +46,6 @@ from tests.common import ( ) from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES -_LOGGER = logging.getLogger(__name__) - class RecordCallsPartial(partial): """Wrapper class for partial.""" @@ -141,6 +138,49 @@ async def test_mqtt_disconnects_on_home_assistant_stop( assert mqtt_client_mock.loop_stop.call_count == 1 +@patch("homeassistant.components.mqtt.PLATFORMS", []) +async def test_mqtt_await_ack_at_disconnect( + hass, +): + """Test if ACK is awaited correctly when disconnecting.""" + + class FakeInfo: + """Returns a simulated client publish response.""" + + mid = 100 + rc = 0 + + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().connect = MagicMock(return_value=0) + mock_client().publish = MagicMock(return_value=FakeInfo()) + entry = MockConfigEntry( + domain=mqtt.DOMAIN, + 
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, + ) + entry.add_to_hass(hass) + assert await mqtt.async_setup_entry(hass, entry) + mqtt_client = mock_client.return_value + + # publish from MQTT client without awaiting + hass.async_create_task( + mqtt.async_publish(hass, "test-topic", "some-payload", 0, False) + ) + await asyncio.sleep(0) + # Simulate late ACK callback from client with mid 100 + mqtt_client.on_publish(0, 0, 100) + # disconnect the MQTT client + await hass.async_stop() + await hass.async_block_till_done() + # assert the payload was sent through the client + assert mqtt_client.publish.called + assert mqtt_client.publish.call_args[0] == ( + "test-topic", + "some-payload", + 0, + False, + ) + + async def test_publish(hass, mqtt_mock_entry_no_yaml_config): """Test the publish function.""" mqtt_mock = await mqtt_mock_entry_no_yaml_config()