diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index cd73ee8efb6..ce538a6af13 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -644,7 +644,6 @@ class MQTT: async def _async_perform_subscriptions(self) -> None: """Perform MQTT client subscriptions.""" - subscriptions: dict[str, int] # Section 3.3.1.3 in the specification: # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html # When sending a PUBLISH Packet to a Client the Server MUST @@ -657,36 +656,26 @@ class MQTT: # Since we do not know if a published value is retained we need to # (re)subscribe, to ensure retained messages are replayed - def _process_client_subscriptions() -> list[tuple[int, int]]: - """Initiate all subscriptions on the MQTT client and return the results.""" - subscribe_result_list = [] - for topic, qos in subscriptions.items(): - result, mid = self._mqttc.subscribe(topic, qos) - subscribe_result_list.append((result, mid)) - _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) - return subscribe_result_list + if not self._pending_subscriptions: + return - subscriptions = self._pending_subscriptions + subscriptions: dict[str, int] = self._pending_subscriptions self._pending_subscriptions = {} async with self._paho_lock: - results = await self.hass.async_add_executor_job( - _process_client_subscriptions + subscription_list = list(subscriptions.items()) + result, mid = await self.hass.async_add_executor_job( + self._mqttc.subscribe, subscription_list ) + + for topic, qos in subscriptions.items(): + _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) self._last_subscribe = time.time() - tasks: list[Coroutine[Any, Any, None]] = [] - errors: list[int] = [] - for result, mid in results: - if result == 0: - tasks.append(self._wait_for_mid(mid)) - else: - errors.append(result) - - if tasks: - await asyncio.gather(*tasks) - if errors: - _raise_on_errors(errors) + if result == 0: + await self._wait_for_mid(mid) + else: + _raise_on_error(result) def _mqtt_on_connect( self, @@ -904,22 +893,13 @@ class MQTT: ) -def _raise_on_errors(result_codes: Iterable[int]) -> None: +def _raise_on_error(result_code: int) -> None: """Raise error if error result.""" # pylint: disable-next=import-outside-toplevel import paho.mqtt.client as mqtt - if messages := [ - mqtt.error_string(result_code) - for result_code in result_codes - if result_code != 0 - ]: - raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}") - - -def _raise_on_error(result_code: int) -> None: - """Raise error if error result.""" - _raise_on_errors((result_code,)) + if result_code and (message := mqtt.error_string(result_code)): + raise HomeAssistantError(f"Error talking to MQTT: {message}") def _matcher_for_topic(subscription: str) -> Any: diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index 620f1e95c23..f9df3450c8d 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -73,6 +73,15 @@ _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType | None MQTT_YAML_SCHEMA = vol.Schema({mqtt.DOMAIN: PLATFORM_CONFIG_SCHEMA_BASE}) +def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]: + """Test of a call.""" + all_calls = [] + for calls in mqtt_client_mock.subscribe.mock_calls: + for call in calls[1]: + all_calls.extend(call) + return all_calls + + def help_test_validate_platform_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType | None: diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 22cf9ecceed..800809f15ad 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -28,7 +28,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.setup import async_setup_component -from .test_common import help_test_unload_config_entry +from .test_common import help_all_subscribe_calls, help_test_unload_config_entry from tests.common import ( MockConfigEntry, @@ -1396,7 +1396,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( await hass.async_block_till_done() await hass.async_block_till_done() - mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert not mqtt_client_mock.unsubscribe.called class TestFlow(config_entries.ConfigFlow): @@ -1407,7 +1407,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( return self.async_abort(reason="already_configured") with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): - mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert not mqtt_client_mock.unsubscribe.called async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") @@ -1443,7 +1443,7 @@ async def test_mqtt_discovery_unsubscribe_once( await hass.async_block_till_done() await hass.async_block_till_done() - mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) + assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert not mqtt_client_mock.unsubscribe.called class TestFlow(config_entries.ConfigFlow): diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 498365de4a3..dec24128cc6 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -36,7 +36,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow -from .test_common import help_test_validate_platform_config +from .test_common import help_all_subscribe_calls, help_test_validate_platform_config from tests.common import ( MockConfigEntry, @@ -1342,16 +1342,16 @@ async def test_unsubscribe_race( # We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or # when both subscriptions were combined [subscribe] expected_calls_1 = [ - call.subscribe("test/state", 0), + call.subscribe([("test/state", 0)]), call.unsubscribe("test/state"), - call.subscribe("test/state", 0), + call.subscribe([("test/state", 0)]), ] expected_calls_2 = [ - call.subscribe("test/state", 0), - call.subscribe("test/state", 0), + call.subscribe([("test/state", 0)]), + call.subscribe([("test/state", 0)]), ] expected_calls_3 = [ - call.subscribe("test/state", 0), + call.subscribe([("test/state", 0)]), ] assert mqtt_client_mock.mock_calls in ( expected_calls_1, @@ -1418,7 +1418,7 @@ async def test_restore_all_active_subscriptions_on_reconnect( # the subscribtion with the highest QoS should survive expected = [ - call("test/state", 2), + call([("test/state", 2)]), ] assert mqtt_client_mock.subscribe.mock_calls == expected @@ -1432,7 +1432,7 @@ async def test_restore_all_active_subscriptions_on_reconnect( async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown await hass.async_block_till_done() - expected.append(call("test/state", 1)) + expected.append(call([("test/state", 1)])) assert mqtt_client_mock.subscribe.mock_calls == expected async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown @@ -1463,9 +1463,7 @@ async def test_subscribed_at_highest_qos( await hass.async_block_till_done() async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown await hass.async_block_till_done() - assert mqtt_client_mock.subscribe.mock_calls == [ - call("test/state", 0), - ] + assert ("test/state", 0) in help_all_subscribe_calls(mqtt_client_mock) mqtt_client_mock.reset_mock() async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown await hass.async_block_till_done() @@ -1477,9 +1475,7 @@ async def test_subscribed_at_highest_qos( async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown await hass.async_block_till_done() # the subscribtion with the highest QoS should survive - assert mqtt_client_mock.subscribe.mock_calls == [ - call("test/state", 2), - ] + assert help_all_subscribe_calls(mqtt_client_mock) == [("test/state", 2)] async def test_reload_entry_with_restored_subscriptions( @@ -2224,10 +2220,49 @@ async def test_mqtt_subscribes_topics_on_connect( assert mqtt_client_mock.disconnect.call_count == 0 - assert mqtt_client_mock.subscribe.call_count == 3 - mqtt_client_mock.subscribe.assert_any_call("topic/test", 0) - mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2) - mqtt_client_mock.subscribe.assert_any_call("still/pending", 1) + subscribe_calls = help_all_subscribe_calls(mqtt_client_mock) + assert len(subscribe_calls) == 3 + assert ("topic/test", 0) in subscribe_calls + assert ("home/sensor", 2) in subscribe_calls + assert ("still/pending", 1) in subscribe_calls + + +@pytest.mark.parametrize( + "mqtt_config_entry_data", + [ + { + mqtt.CONF_BROKER: "mock-broker", + mqtt.CONF_BIRTH_MESSAGE: {}, + mqtt.CONF_DISCOVERY: False, + } + ], +) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +async def test_mqtt_subscribes_in_single_call( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + record_calls: MessageCallbackType, +) -> None: + """Test bundled client subscription to topic.""" + mqtt_mock = await mqtt_mock_entry() + # Fake that the client is connected + mqtt_mock().connected = True + + mqtt_client_mock.subscribe.reset_mock() + await mqtt.async_subscribe(hass, "topic/test", record_calls) + await mqtt.async_subscribe(hass, "home/sensor", record_calls) + await hass.async_block_till_done() + # Make sure the debouncer finishes + await asyncio.sleep(0.2) + + assert mqtt_client_mock.subscribe.call_count == 1 + # Assert we have a single subscription call with both subscriptions + assert mqtt_client_mock.subscribe.mock_calls[0][1][0] in [ + [("topic/test", 0), ("home/sensor", 0)], + [("home/sensor", 0), ("topic/test", 0)], + ] async def test_default_entry_setting_are_applied(