diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 66699372516..d676c128260 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterable from functools import lru_cache, partial, wraps import inspect from itertools import groupby @@ -430,7 +430,7 @@ class MQTT: # Only subscribe if currently connected. if self.connected: self._last_subscribe = time.time() - await self._async_perform_subscription(topic, qos) + await self._async_perform_subscriptions(((topic, qos),)) @callback def async_remove() -> None: @@ -464,16 +464,37 @@ class MQTT: _raise_on_error(result) await self._wait_for_mid(mid) - async def _async_perform_subscription(self, topic: str, qos: int) -> None: - """Perform a paho-mqtt subscription.""" + async def _async_perform_subscriptions( + self, subscriptions: Iterable[tuple[str, int]] + ) -> None: + """Perform MQTT client subscriptions.""" + + def _process_client_subscriptions() -> list[tuple[int, int]]: + """Initiate all subscriptions on the MQTT client and return the results.""" + subscribe_result_list = [] + for topic, qos in subscriptions: + result, mid = self._mqttc.subscribe(topic, qos) + subscribe_result_list.append((result, mid)) + _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid) + return subscribe_result_list + async with self._paho_lock: - result: int | None = None - result, mid = await self.hass.async_add_executor_job( - self._mqttc.subscribe, topic, qos + results = await self.hass.async_add_executor_job( + _process_client_subscriptions ) - _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid) - _raise_on_error(result) - await self._wait_for_mid(mid) + + tasks = [] + errors = [] + for result, mid in results: + if result == 0: + tasks.append(self._wait_for_mid(mid)) + else: + errors.append(result) + + if tasks: + await asyncio.gather(*tasks) + if errors: + _raise_on_errors(errors) def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None: """On connect callback. @@ -502,10 +523,16 @@ class MQTT: # Group subscriptions to only re-subscribe once for each topic. keyfunc = attrgetter("topic") - for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc): - # Re-subscribe with the highest requested qos - max_qos = max(subscription.qos for subscription in subs) - self.hass.add_job(self._async_perform_subscription, topic, max_qos) + self.hass.add_job( + self._async_perform_subscriptions, + [ + # Re-subscribe with the highest requested qos + (topic, max(subscription.qos for subscription in subs)) + for topic, subs in groupby( + sorted(self.subscriptions, key=keyfunc), keyfunc + ) + ], + ) if ( CONF_BIRTH_MESSAGE in self.conf @@ -638,15 +665,22 @@ class MQTT: ) -def _raise_on_error(result_code: int | None) -> None: +def _raise_on_errors(result_codes: Iterable[int | None]) -> None: """Raise error if error result.""" # pylint: disable-next=import-outside-toplevel import paho.mqtt.client as mqtt - if result_code is not None and result_code != 0: - raise HomeAssistantError( - f"Error talking to MQTT: {mqtt.error_string(result_code)}" - ) + if messages := [ + mqtt.error_string(result_code) + for result_code in result_codes + if result_code != 0 + ]: + raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}") + + +def _raise_on_error(result_code: int | None) -> None: + """Raise error if error result.""" + _raise_on_errors((result_code,)) def _matcher_for_topic(subscription: str) -> Any: diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index a29f1fd88ef..b435798c241 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1312,6 +1312,20 @@ async def test_publish_error(hass, caplog): assert "Failed to connect to MQTT server: Out of memory." in caplog.text +async def test_subscribe_error( + hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock +): + """Test publish error.""" + await mqtt_mock_entry_no_yaml_config() + mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0) + await hass.async_block_till_done() + with pytest.raises(HomeAssistantError): + # simulate client is not connected error before subscribing + mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None) + await mqtt.async_subscribe(hass, "some-topic", lambda *args: 0) + await hass.async_block_till_done() + + async def test_handle_message_callback( hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock ): @@ -1424,6 +1438,7 @@ async def test_setup_mqtt_client_protocol(hass): @patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2) +@patch("homeassistant.components.mqtt.PLATFORMS", []) async def test_handle_mqtt_timeout_on_callback(hass, caplog): """Test publish without receiving an ACK callback.""" mid = 0 @@ -1764,9 +1779,12 @@ async def test_mqtt_subscribes_topics_on_connect( assert mqtt_client_mock.disconnect.call_count == 0 - expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1} - calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls} - assert calls == expected + assert len(hass.add_job.mock_calls) == 1 + assert set(hass.add_job.mock_calls[0][1][1]) == { + ("home/sensor", 2), + ("still/pending", 1), + ("topic/test", 0), + } async def test_setup_entry_with_config_override(