Fix race when unsubscribing from MQTT topics (#67376)

* Fix race when unsubscribing from MQTT topics

* Improve test
This commit is contained in:
Erik Montnemery 2022-02-28 13:19:50 +00:00 committed by GitHub
parent 0db6a0b248
commit c7d59bb272
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 4 deletions

View file

@ -967,10 +967,6 @@ class MQTT:
self.subscriptions.remove(subscription)
self._matching_subscriptions.cache_clear()
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return
# Only unsubscribe if currently connected.
if self.connected:
self.hass.async_create_task(self._async_unsubscribe(topic))
@ -982,6 +978,10 @@ class MQTT:
This method is a coroutine.
"""
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return
async with self._paho_lock:
result: int | None = None
result, mid = await self.hass.async_add_executor_job(

View file

@ -1056,6 +1056,38 @@ async def test_not_calling_unsubscribe_with_active_subscribers(
assert not mqtt_client_mock.unsubscribe.called
async def test_unsubscribe_race(hass, mqtt_client_mock, mqtt_mock):
"""Test not calling unsubscribe() when other subscribers are active."""
# Fake that the client is connected
mqtt_mock().connected = True
calls_a = MagicMock()
calls_b = MagicMock()
mqtt_client_mock.reset_mock()
unsub = await mqtt.async_subscribe(hass, "test/state", calls_a)
unsub()
await mqtt.async_subscribe(hass, "test/state", calls_b)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test/state", "online")
await hass.async_block_till_done()
assert not calls_a.called
assert calls_b.called
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
expected_calls_1 = [
call.subscribe("test/state", 0),
call.unsubscribe("test/state"),
call.subscribe("test/state", 0),
]
expected_calls_2 = [
call.subscribe("test/state", 0),
call.subscribe("test/state", 0),
]
assert mqtt_client_mock.mock_calls in (expected_calls_1, expected_calls_2)
@pytest.mark.parametrize(
"mqtt_config",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],