Fix MQTT race awaiting an ACK when disconnecting (#75117)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2022-07-20 11:58:54 +02:00 committed by GitHub
parent 11e7ddaa71
commit 5ef92e5e95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 28 deletions

View file

@ -325,11 +325,11 @@ class MQTT:
self._ha_started = asyncio.Event()
self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
self._pending_acks: set[int] = set()
self._cleanup_on_unload: list[Callable] = []
self._pending_operations: dict[str, asyncio.Event] = {}
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition()
if self.hass.state == CoreState.running:
self._ha_started.set()
@ -431,13 +431,13 @@ class MQTT:
# Do not disconnect, we want the broker to always publish will
self._mqttc.loop_stop()
# wait for ACK-s to be processed (unsubscribe only)
async with self._paho_lock:
tasks = [
self.hass.async_create_task(self._wait_for_mid(mid))
for mid in self._pending_acks
]
await asyncio.gather(*tasks)
def no_more_acks() -> bool:
"""Return True once no ACKs are pending, False while any remain.

Serves as the predicate passed to
``self._pending_operations_condition.wait_for``.
"""
return not bool(self._pending_operations)
# wait for ACK-s to be processed (unsubscribe only)
async with self._pending_operations_condition:
await self._pending_operations_condition.wait_for(no_more_acks)
# stop the MQTT loop
await self.hass.async_add_executor_job(stop)
@ -487,19 +487,21 @@ class MQTT:
This method is a coroutine.
"""
def _client_unsubscribe(topic: str) -> None:
def _client_unsubscribe(topic: str) -> int:
result: int | None = None
result, mid = self._mqttc.unsubscribe(topic)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result)
self._pending_acks.add(mid)
return mid
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return
async with self._paho_lock:
await self.hass.async_add_executor_job(_client_unsubscribe, topic)
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
await self._register_mid(mid)
self.hass.async_create_task(self._wait_for_mid(mid))
async def _async_perform_subscriptions(
self, subscriptions: Iterable[tuple[str, int]]
@ -647,14 +649,18 @@ class MQTT:
"""Publish / Subscribe / Unsubscribe callback."""
self.hass.add_job(self._mqtt_handle_mid, mid)
@callback
def _mqtt_handle_mid(self, mid) -> None:
async def _mqtt_handle_mid(self, mid: int) -> None:
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first.
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
await self._register_mid(mid)
self._pending_operations[mid].set()
async def _register_mid(self, mid: int) -> None:
"""Create Event for an expected ACK.

The check and insert run while holding ``_pending_operations_condition``.
An entry that already exists for ``mid`` is left untouched, so calling
this more than once for the same mid keeps the first Event.
"""
async with self._pending_operations_condition:
if mid not in self._pending_operations:
# First registration for this mid: store a fresh, unset Event under it.
self._pending_operations[mid] = asyncio.Event()
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback."""
self.connected = False
@ -666,12 +672,11 @@ class MQTT:
result_code,
)
async def _wait_for_mid(self, mid):
async def _wait_for_mid(self, mid: int) -> None:
"""Wait for ACK from broker."""
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
# may be executed first.
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
await self._register_mid(mid)
try:
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
except asyncio.TimeoutError:
@ -679,11 +684,10 @@ class MQTT:
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
)
finally:
del self._pending_operations[mid]
# Cleanup ACK sync buffer
async with self._paho_lock:
if mid in self._pending_acks:
self._pending_acks.remove(mid)
async with self._pending_operations_condition:
# Cleanup ACK sync buffer
del self._pending_operations[mid]
self._pending_operations_condition.notify_all()
async def _discovery_cooldown(self):
now = time.time()

View file

@ -4,7 +4,6 @@ import copy
from datetime import datetime, timedelta
from functools import partial
import json
import logging
import ssl
from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch
@ -47,8 +46,6 @@ from tests.common import (
)
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
_LOGGER = logging.getLogger(__name__)
class RecordCallsPartial(partial):
"""Wrapper class for partial."""
@ -141,6 +138,49 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
assert mqtt_client_mock.loop_stop.call_count == 1
@patch("homeassistant.components.mqtt.PLATFORMS", [])
async def test_mqtt_await_ack_at_disconnect(
hass,
):
"""Test if ACK is awaited correctly when disconnecting.

A publish is started without being awaited, the matching on_publish
callback (mid 100) is delivered, and hass is then stopped. The test
asserts that the publish call reached the mocked client.
"""
class FakeInfo:
"""Simulate the response object returned by the client's publish call."""
# mid matches the value passed to on_publish further down.
# rc = 0 is presumably the client's success code — TODO confirm.
mid = 100
rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
# Calling mock_client() yields mock_client.return_value, so these
# assignments configure the same mock that is read as mqtt_client below.
mock_client().connect = MagicMock(return_value=0)
mock_client().publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
)
entry.add_to_hass(hass)
assert await mqtt.async_setup_entry(hass, entry)
mqtt_client = mock_client.return_value
# publish from MQTT client without awaiting
hass.async_create_task(
mqtt.async_publish(hass, "test-topic", "some-payload", 0, False)
)
# Yield to the event loop once so the scheduled publish task gets a
# chance to run before the ACK is simulated.
await asyncio.sleep(0)
# Simulate late ACK callback from client with mid 100
mqtt_client.on_publish(0, 0, 100)
# disconnect the MQTT client
await hass.async_stop()
await hass.async_block_till_done()
# assert the payload was sent through the client
assert mqtt_client.publish.called
assert mqtt_client.publish.call_args[0] == (
"test-topic",
"some-payload",
0,
False,
)
async def test_publish(hass, mqtt_mock_entry_no_yaml_config):
"""Test the publish function."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config()