Fix MQTT race awaiting an ACK when disconnecting (#75117)
Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
11e7ddaa71
commit
5ef92e5e95
2 changed files with 72 additions and 28 deletions
|
@ -325,11 +325,11 @@ class MQTT:
|
||||||
self._ha_started = asyncio.Event()
|
self._ha_started = asyncio.Event()
|
||||||
self._last_subscribe = time.time()
|
self._last_subscribe = time.time()
|
||||||
self._mqttc: mqtt.Client = None
|
self._mqttc: mqtt.Client = None
|
||||||
self._paho_lock = asyncio.Lock()
|
|
||||||
self._pending_acks: set[int] = set()
|
|
||||||
self._cleanup_on_unload: list[Callable] = []
|
self._cleanup_on_unload: list[Callable] = []
|
||||||
|
|
||||||
self._pending_operations: dict[str, asyncio.Event] = {}
|
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
|
||||||
|
self._pending_operations: dict[int, asyncio.Event] = {}
|
||||||
|
self._pending_operations_condition = asyncio.Condition()
|
||||||
|
|
||||||
if self.hass.state == CoreState.running:
|
if self.hass.state == CoreState.running:
|
||||||
self._ha_started.set()
|
self._ha_started.set()
|
||||||
|
@ -431,13 +431,13 @@ class MQTT:
|
||||||
# Do not disconnect, we want the broker to always publish will
|
# Do not disconnect, we want the broker to always publish will
|
||||||
self._mqttc.loop_stop()
|
self._mqttc.loop_stop()
|
||||||
|
|
||||||
# wait for ACK-s to be processed (unsubscribe only)
|
def no_more_acks() -> bool:
|
||||||
async with self._paho_lock:
|
"""Return False if there are unprocessed ACKs."""
|
||||||
tasks = [
|
return not bool(self._pending_operations)
|
||||||
self.hass.async_create_task(self._wait_for_mid(mid))
|
|
||||||
for mid in self._pending_acks
|
# wait for ACK-s to be processed (unsubscribe only)
|
||||||
]
|
async with self._pending_operations_condition:
|
||||||
await asyncio.gather(*tasks)
|
await self._pending_operations_condition.wait_for(no_more_acks)
|
||||||
|
|
||||||
# stop the MQTT loop
|
# stop the MQTT loop
|
||||||
await self.hass.async_add_executor_job(stop)
|
await self.hass.async_add_executor_job(stop)
|
||||||
|
@ -487,19 +487,21 @@ class MQTT:
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _client_unsubscribe(topic: str) -> None:
|
def _client_unsubscribe(topic: str) -> int:
|
||||||
result: int | None = None
|
result: int | None = None
|
||||||
result, mid = self._mqttc.unsubscribe(topic)
|
result, mid = self._mqttc.unsubscribe(topic)
|
||||||
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
self._pending_acks.add(mid)
|
return mid
|
||||||
|
|
||||||
if any(other.topic == topic for other in self.subscriptions):
|
if any(other.topic == topic for other in self.subscriptions):
|
||||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||||
return
|
return
|
||||||
|
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
await self.hass.async_add_executor_job(_client_unsubscribe, topic)
|
mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic)
|
||||||
|
await self._register_mid(mid)
|
||||||
|
self.hass.async_create_task(self._wait_for_mid(mid))
|
||||||
|
|
||||||
async def _async_perform_subscriptions(
|
async def _async_perform_subscriptions(
|
||||||
self, subscriptions: Iterable[tuple[str, int]]
|
self, subscriptions: Iterable[tuple[str, int]]
|
||||||
|
@ -647,14 +649,18 @@ class MQTT:
|
||||||
"""Publish / Subscribe / Unsubscribe callback."""
|
"""Publish / Subscribe / Unsubscribe callback."""
|
||||||
self.hass.add_job(self._mqtt_handle_mid, mid)
|
self.hass.add_job(self._mqtt_handle_mid, mid)
|
||||||
|
|
||||||
@callback
|
async def _mqtt_handle_mid(self, mid: int) -> None:
|
||||||
def _mqtt_handle_mid(self, mid) -> None:
|
|
||||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||||
# may be executed first.
|
# may be executed first.
|
||||||
if mid not in self._pending_operations:
|
await self._register_mid(mid)
|
||||||
self._pending_operations[mid] = asyncio.Event()
|
|
||||||
self._pending_operations[mid].set()
|
self._pending_operations[mid].set()
|
||||||
|
|
||||||
|
async def _register_mid(self, mid: int) -> None:
|
||||||
|
"""Create Event for an expected ACK."""
|
||||||
|
async with self._pending_operations_condition:
|
||||||
|
if mid not in self._pending_operations:
|
||||||
|
self._pending_operations[mid] = asyncio.Event()
|
||||||
|
|
||||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
self.connected = False
|
self.connected = False
|
||||||
|
@ -666,12 +672,11 @@ class MQTT:
|
||||||
result_code,
|
result_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _wait_for_mid(self, mid):
|
async def _wait_for_mid(self, mid: int) -> None:
|
||||||
"""Wait for ACK from broker."""
|
"""Wait for ACK from broker."""
|
||||||
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
|
||||||
# may be executed first.
|
# may be executed first.
|
||||||
if mid not in self._pending_operations:
|
await self._register_mid(mid)
|
||||||
self._pending_operations[mid] = asyncio.Event()
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
|
await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
@ -679,11 +684,10 @@ class MQTT:
|
||||||
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
|
"No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
del self._pending_operations[mid]
|
async with self._pending_operations_condition:
|
||||||
# Cleanup ACK sync buffer
|
# Cleanup ACK sync buffer
|
||||||
async with self._paho_lock:
|
del self._pending_operations[mid]
|
||||||
if mid in self._pending_acks:
|
self._pending_operations_condition.notify_all()
|
||||||
self._pending_acks.remove(mid)
|
|
||||||
|
|
||||||
async def _discovery_cooldown(self):
|
async def _discovery_cooldown(self):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
|
@ -4,7 +4,6 @@ import copy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import ssl
|
import ssl
|
||||||
from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch
|
from unittest.mock import ANY, AsyncMock, MagicMock, call, mock_open, patch
|
||||||
|
|
||||||
|
@ -47,8 +46,6 @@ from tests.common import (
|
||||||
)
|
)
|
||||||
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
|
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RecordCallsPartial(partial):
|
class RecordCallsPartial(partial):
|
||||||
"""Wrapper class for partial."""
|
"""Wrapper class for partial."""
|
||||||
|
@ -141,6 +138,49 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
|
||||||
assert mqtt_client_mock.loop_stop.call_count == 1
|
assert mqtt_client_mock.loop_stop.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||||
|
async def test_mqtt_await_ack_at_disconnect(
|
||||||
|
hass,
|
||||||
|
):
|
||||||
|
"""Test if ACK is awaited correctly when disconnecting."""
|
||||||
|
|
||||||
|
class FakeInfo:
|
||||||
|
"""Returns a simulated client publish response."""
|
||||||
|
|
||||||
|
mid = 100
|
||||||
|
rc = 0
|
||||||
|
|
||||||
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
|
mock_client().connect = MagicMock(return_value=0)
|
||||||
|
mock_client().publish = MagicMock(return_value=FakeInfo())
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
assert await mqtt.async_setup_entry(hass, entry)
|
||||||
|
mqtt_client = mock_client.return_value
|
||||||
|
|
||||||
|
# publish from MQTT client without awaiting
|
||||||
|
hass.async_create_task(
|
||||||
|
mqtt.async_publish(hass, "test-topic", "some-payload", 0, False)
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Simulate late ACK callback from client with mid 100
|
||||||
|
mqtt_client.on_publish(0, 0, 100)
|
||||||
|
# disconnect the MQTT client
|
||||||
|
await hass.async_stop()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
# assert the payload was sent through the client
|
||||||
|
assert mqtt_client.publish.called
|
||||||
|
assert mqtt_client.publish.call_args[0] == (
|
||||||
|
"test-topic",
|
||||||
|
"some-payload",
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_publish(hass, mqtt_mock_entry_no_yaml_config):
|
async def test_publish(hass, mqtt_mock_entry_no_yaml_config):
|
||||||
"""Test the publish function."""
|
"""Test the publish function."""
|
||||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||||
|
|
Loading…
Add table
Reference in a new issue