Fix MQTT race awaiting an ACK when disconnecting (#75117)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
11e7ddaa71
commit
5ef92e5e95
2 changed files with 72 additions and 28 deletions
|
@ -325,11 +325,11 @@ class MQTT:
|
|||
self._ha_started = asyncio.Event()
|
||||
self._last_subscribe = time.time()
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._paho_lock = asyncio.Lock()
|
||||
self._pending_acks: set[int] = set()
|
||||
self._cleanup_on_unload: list[Callable] = []
|
||||
|
||||
self._pending_operations: dict[str, asyncio.Event] = {}
|
||||
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
|
||||
self._pending_operations: dict[int, asyncio.Event] = {}
|
||||
self._pending_operations_condition = asyncio.Condition()
|
||||
|
||||
if self.hass.state == CoreState.running:
|
||||
self._ha_started.set()
|
||||
|
@ -431,13 +431,13 @@ class MQTT:
|
|||
# Do not disconnect, we want the broker to always publish will
|
||||
self._mqttc.loop_stop()
|
||||
|
||||
# wait for ACK-s to be processes (unsubscribe only)
|
||||
async with self._paho_lock:
|
||||
tasks = [
|
||||
self.hass.async_create_task(self._wait_for_mid(mid))
|
||||
for mid in self._pending_acks
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
def no_more_acks() -> bool:
|
||||
"""Return False if there are unprocessed ACKs."""
|
||||
return not bool(self._pending_operations)
|
||||
|
||||
# wait for ACK-s to be processesed (unsubscribe only)
|
||||
async with self._pending_operations_condition:
|
||||
await self._pending_operations_condition.wait_for(no_more_acks)
|
||||
|
||||
# stop the MQTT loop
|
||||
await self.hass.async_add_executor_job(stop)
|
||||
|
@ -487,19 +487,21 @@ class MQTT:
|
|||
This method is a coroutine.
|
||||
"""
|
||||
|
||||
def _client_unsubscribe(topic: str) -> None:
|
||||
def _client_unsubscribe(topic: str) -> int:
|
||||
result: int | None = None
|
||||
result, mid = self._mqttc.unsubscribe(topic)
|
||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||
_raise_on_error(result)
|
||||
self._pending_acks.add(mid)
|
||||
return mid
|
||||
|
||||
if any(other.topic == topic for other in self.subscriptions):
|
||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||
return
|
||||
|
||||
async with self._paho_lock:
|
||||
await self.hass.async_add_executor_job(_client_unsubscribe, topic)
|
||||
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
|
||||
await self._register_mid(mid)
|
||||
self.hass.async_create_task(self._wait_for_mid(mid))
|
||||
|
||||
async def _async_perform_subscriptions(
|
||||
self, subscriptions: Iterable[tuple[str, int]]
|
||||
|
@ -647,14 +649,18 @@ class MQTT:
|
|||
"""Publish / Subscribe / Unsubscribe callback."""
|
||||
self.hass.add_job(self._mqtt_handle_mid, mid)
|
||||
|
||||
@callback
|
||||
def _mqtt_handle_mid(self, mid) -> None:
|
||||
async def _mqtt_handle_mid(self, mid: int) -> None:
|
||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||
# may be executed first.
|
||||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
await self._register_mid(mid)
|
||||
self._pending_operations[mid].set()
|
||||
|
||||
async def _register_mid(self, mid: int) -> None:
|
||||
"""Create Event for an expected ACK."""
|
||||
async with self._pending_operations_condition:
|
||||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
|
||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||
"""Disconnected callback."""
|
||||
self.connected = False
|
||||
|
@ -666,12 +672,11 @@ class MQTT:
|
|||
result_code,
|
||||
)
|
||||
|
||||
async def _wait_for_mid(self, mid):
|
||||
async def _wait_for_mid(self, mid: int) -> None:
|
||||
"""Wait for ACK from broker."""
|
||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||
# may be executed first.
|
||||
if mid not in self._pending_operations:
|
||||
self._pending_operations[mid] = asyncio.Event()
|
||||
await self._register_mid(mid)
|
||||
try:
|
||||
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
|
||||
except asyncio.TimeoutError:
|
||||
|
@ -679,11 +684,10 @@ class MQTT:
|
|||
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
|
||||
)
|
||||
finally:
|
||||
del self._pending_operations[mid]
|
||||
# Cleanup ACK sync buffer
|
||||
async with self._paho_lock:
|
||||
if mid in self._pending_acks:
|
||||
self._pending_acks.remove(mid)
|
||||
async with self._pending_operations_condition:
|
||||
# Cleanup ACK sync buffer
|
||||
del self._pending_operations[mid]
|
||||
self._pending_operations_condition.notify_all()
|
||||
|
||||
async def _discovery_cooldown(self):
|
||||
now = time.time()
|
||||
|
|
|
@ -4,7 +4,6 @@ import copy
|
|||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch
|
||||
|
||||
|
@ -47,8 +46,6 @@ from tests.common import (
|
|||
)
|
||||
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecordCallsPartial(partial):
|
||||
"""Wrapper class for partial."""
|
||||
|
@ -141,6 +138,49 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
|
|||
assert mqtt_client_mock.loop_stop.call_count == 1
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||
async def test_mqtt_await_ack_at_disconnect(
|
||||
hass,
|
||||
):
|
||||
"""Test if ACK is awaited correctly when disconnecting."""
|
||||
|
||||
class FakeInfo:
|
||||
"""Returns a simulated client publish response."""
|
||||
|
||||
mid = 100
|
||||
rc = 0
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().connect = MagicMock(return_value=0)
|
||||
mock_client().publish = MagicMock(return_value=FakeInfo())
|
||||
entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
assert await mqtt.async_setup_entry(hass, entry)
|
||||
mqtt_client = mock_client.return_value
|
||||
|
||||
# publish from MQTT client without awaiting
|
||||
hass.async_create_task(
|
||||
mqtt.async_publish(hass, "test-topic", "some-payload", 0, False)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# Simulate late ACK callback from client with mid 100
|
||||
mqtt_client.on_publish(0, 0, 100)
|
||||
# disconnect the MQTT client
|
||||
await hass.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
# assert the payload was sent through the client
|
||||
assert mqtt_client.publish.called
|
||||
assert mqtt_client.publish.call_args[0] == (
|
||||
"test-topic",
|
||||
"some-payload",
|
||||
0,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
async def test_publish(hass, mqtt_mock_entry_no_yaml_config):
|
||||
"""Test the publish function."""
|
||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||
|
|
Loading…
Add table
Reference in a new issue