Fix MQTT race awaiting an ACK when disconnecting (#75117)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2022-07-20 11:58:54 +02:00 committed by GitHub
parent 11e7ddaa71
commit 5ef92e5e95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 28 deletions

View file

@ -325,11 +325,11 @@ class MQTT:
self._ha_started = asyncio.Event() self._ha_started = asyncio.Event()
self._last_subscribe = time.time() self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
self._pending_acks: set[int] = set()
self._cleanup_on_unload: list[Callable] = [] self._cleanup_on_unload: list[Callable] = []
self._pending_operations: dict[str, asyncio.Event] = {} self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition()
if self.hass.state == CoreState.running: if self.hass.state == CoreState.running:
self._ha_started.set() self._ha_started.set()
@ -431,13 +431,13 @@ class MQTT:
# Do not disconnect, we want the broker to always publish will # Do not disconnect, we want the broker to always publish will
self._mqttc.loop_stop() self._mqttc.loop_stop()
# wait for ACK-s to be processes (unsubscribe only) def no_more_acks() -> bool:
async with self._paho_lock: """Return False if there are unprocessed ACKs."""
tasks = [ return not bool(self._pending_operations)
self.hass.async_create_task(self._wait_for_mid(mid))
for mid in self._pending_acks # wait for ACK-s to be processesed (unsubscribe only)
] async with self._pending_operations_condition:
await asyncio.gather(*tasks) await self._pending_operations_condition.wait_for(no_more_acks)
# stop the MQTT loop # stop the MQTT loop
await self.hass.async_add_executor_job(stop) await self.hass.async_add_executor_job(stop)
@ -487,19 +487,21 @@ class MQTT:
This method is a coroutine. This method is a coroutine.
""" """
def _client_unsubscribe(topic: str) -> None: def _client_unsubscribe(topic: str) -> int:
result: int | None = None result: int | None = None
result, mid = self._mqttc.unsubscribe(topic) result, mid = self._mqttc.unsubscribe(topic)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result) _raise_on_error(result)
self._pending_acks.add(mid) return mid
if any(other.topic == topic for other in self.subscriptions): if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe. # Other subscriptions on topic remaining - don't unsubscribe.
return return
async with self._paho_lock: async with self._paho_lock:
await self.hass.async_add_executor_job(_client_unsubscribe, topic) mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
await self._register_mid(mid)
self.hass.async_create_task(self._wait_for_mid(mid))
async def _async_perform_subscriptions( async def _async_perform_subscriptions(
self, subscriptions: Iterable[tuple[str, int]] self, subscriptions: Iterable[tuple[str, int]]
@ -647,14 +649,18 @@ class MQTT:
"""Publish / Subscribe / Unsubscribe callback.""" """Publish / Subscribe / Unsubscribe callback."""
self.hass.add_job(self._mqtt_handle_mid, mid) self.hass.add_job(self._mqtt_handle_mid, mid)
@callback async def _mqtt_handle_mid(self, mid: int) -> None:
def _mqtt_handle_mid(self, mid) -> None:
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first. # may be executed first.
if mid not in self._pending_operations: await self._register_mid(mid)
self._pending_operations[mid] = asyncio.Event()
self._pending_operations[mid].set() self._pending_operations[mid].set()
async def _register_mid(self, mid: int) -> None:
"""Create Event for an expected ACK."""
async with self._pending_operations_condition:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback.""" """Disconnected callback."""
self.connected = False self.connected = False
@ -666,12 +672,11 @@ class MQTT:
result_code, result_code,
) )
async def _wait_for_mid(self, mid): async def _wait_for_mid(self, mid: int) -> None:
"""Wait for ACK from broker.""" """Wait for ACK from broker."""
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first. # may be executed first.
if mid not in self._pending_operations: await self._register_mid(mid)
self._pending_operations[mid] = asyncio.Event()
try: try:
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK) await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -679,11 +684,10 @@ class MQTT:
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
) )
finally: finally:
del self._pending_operations[mid] async with self._pending_operations_condition:
# Cleanup ACK sync buffer # Cleanup ACK sync buffer
async with self._paho_lock: del self._pending_operations[mid]
if mid in self._pending_acks: self._pending_operations_condition.notify_all()
self._pending_acks.remove(mid)
async def _discovery_cooldown(self): async def _discovery_cooldown(self):
now = time.time() now = time.time()

View file

@ -4,7 +4,6 @@ import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import json import json
import logging
import ssl import ssl
from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch
@ -47,8 +46,6 @@ from tests.common import (
) )
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
_LOGGER = logging.getLogger(__name__)
class RecordCallsPartial(partial): class RecordCallsPartial(partial):
"""Wrapper class for partial.""" """Wrapper class for partial."""
@ -141,6 +138,49 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
assert mqtt_client_mock.loop_stop.call_count == 1 assert mqtt_client_mock.loop_stop.call_count == 1
@patch("homeassistant.components.mqtt.PLATFORMS", [])
async def test_mqtt_await_ack_at_disconnect(
hass,
):
"""Test if ACK is awaited correctly when disconnecting."""
class FakeInfo:
"""Returns a simulated client publish response."""
mid = 100
rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect = MagicMock(return_value=0)
mock_client().publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
)
entry.add_to_hass(hass)
assert await mqtt.async_setup_entry(hass, entry)
mqtt_client = mock_client.return_value
# publish from MQTT client without awaiting
hass.async_create_task(
mqtt.async_publish(hass, "test-topic", "some-payload", 0, False)
)
await asyncio.sleep(0)
# Simulate late ACK callback from client with mid 100
mqtt_client.on_publish(0, 0, 100)
# disconnect the MQTT client
await hass.async_stop()
await hass.async_block_till_done()
# assert the payload was sent through the client
assert mqtt_client.publish.called
assert mqtt_client.publish.call_args[0] == (
"test-topic",
"some-payload",
0,
False,
)
async def test_publish(hass, mqtt_mock_entry_no_yaml_config): async def test_publish(hass, mqtt_mock_entry_no_yaml_config):
"""Test the publish function.""" """Test the publish function."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config() mqtt_mock = await mqtt_mock_entry_no_yaml_config()