diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 1982d1f3df5..107bc4660c2 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -967,10 +967,6 @@ class MQTT: self.subscriptions.remove(subscription) self._matching_subscriptions.cache_clear() - if any(other.topic == topic for other in self.subscriptions): - # Other subscriptions on topic remaining - don't unsubscribe. - return - # Only unsubscribe if currently connected. if self.connected: self.hass.async_create_task(self._async_unsubscribe(topic)) @@ -982,6 +978,10 @@ class MQTT: This method is a coroutine. """ + if any(other.topic == topic for other in self.subscriptions): + # Other subscriptions on topic remaining - don't unsubscribe. + return + async with self._paho_lock: result: int | None = None result, mid = await self.hass.async_add_executor_job( diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index e589e447a01..7296d4e8101 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1056,6 +1056,38 @@ async def test_not_calling_unsubscribe_with_active_subscribers( assert not mqtt_client_mock.unsubscribe.called +async def test_unsubscribe_race(hass, mqtt_client_mock, mqtt_mock): + """Test not calling unsubscribe() when other subscribers are active.""" + # Fake that the client is connected + mqtt_mock().connected = True + + calls_a = MagicMock() + calls_b = MagicMock() + + mqtt_client_mock.reset_mock() + unsub = await mqtt.async_subscribe(hass, "test/state", calls_a) + unsub() + await mqtt.async_subscribe(hass, "test/state", calls_b) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "test/state", "online") + await hass.async_block_till_done() + assert not calls_a.called + assert calls_b.called + + # We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe] + expected_calls_1 = [ + call.subscribe("test/state", 0), + call.unsubscribe("test/state"), + call.subscribe("test/state", 0), + ] + expected_calls_2 = [ + call.subscribe("test/state", 0), + call.subscribe("test/state", 0), + ] + assert mqtt_client_mock.mock_calls in (expected_calls_1, expected_calls_2) + + @pytest.mark.parametrize( "mqtt_config", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],