From fa1ef8b0cfc1dd43b32eeb17e7a05c9f01bb6de3 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Sat, 25 May 2024 01:33:28 +0200 Subject: [PATCH] Split mqtt subscribe and unsubscribe calls to smaller chunks (#118035) --- homeassistant/components/mqtt/client.py | 39 +++++++++++------- tests/components/mqtt/test_init.py | 55 ++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index e906c4df91b..d62e66f8b0a 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -34,6 +34,7 @@ from homeassistant.helpers.start import async_at_started from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.util.async_ import create_eager_task +from homeassistant.util.collection import chunked_or_all from homeassistant.util.logging import catch_log_exception from .const import ( @@ -100,6 +101,9 @@ UNSUBSCRIBE_COOLDOWN = 0.1 TIMEOUT_ACK = 10 RECONNECT_INTERVAL_SECONDS = 10 +MAX_SUBSCRIBES_PER_CALL = 500 +MAX_UNSUBSCRIBES_PER_CALL = 500 + type SocketType = socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any type SubscribePayloadType = str | bytes # Only bytes if encoding is None @@ -905,17 +909,21 @@ class MQTT: self._pending_subscriptions = {} subscription_list = list(subscriptions.items()) - result, mid = self._mqttc.subscribe(subscription_list) - if _LOGGER.isEnabledFor(logging.DEBUG): - for topic, qos in subscriptions.items(): - _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) - self._last_subscribe = time.monotonic() + for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL): + result, mid = self._mqttc.subscribe(chunk) - if result == 0: - await self._async_wait_for_mid(mid) - else: - _raise_on_error(result) + if _LOGGER.isEnabledFor(logging.DEBUG): + for topic, qos in subscriptions.items(): + _LOGGER.debug( + "Subscribing to %s, mid: %s, qos: %s", topic, mid, qos + ) + self._last_subscribe = time.monotonic() + + if result == 0: + await self._async_wait_for_mid(mid) + else: + _raise_on_error(result) async def _async_perform_unsubscribes(self) -> None: """Perform pending MQTT client unsubscribes.""" @@ -925,13 +933,14 @@ class MQTT: topics = list(self._pending_unsubscribes) self._pending_unsubscribes = set() - result, mid = self._mqttc.unsubscribe(topics) - _raise_on_error(result) - if _LOGGER.isEnabledFor(logging.DEBUG): - for topic in topics: - _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) + for chunk in chunked_or_all(topics, MAX_UNSUBSCRIBES_PER_CALL): + result, mid = self._mqttc.unsubscribe(chunk) + _raise_on_error(result) + if _LOGGER.isEnabledFor(logging.DEBUG): + for topic in chunk: + _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) - await self._async_wait_for_mid(mid) + await self._async_wait_for_mid(mid) async def _async_resubscribe_and_publish_birth_message( self, birth_message: PublishMessage diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 6a744b8edfb..be1304cdbfe 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -44,7 +44,7 @@ from homeassistant.const import ( UnitOfTemperature, ) import homeassistant.core as ha -from homeassistant.core import CoreState, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers import device_registry as dr, entity_registry as er, template from homeassistant.helpers.entity import Entity @@ -2816,6 +2816,59 @@ async def test_mqtt_subscribes_in_single_call( ] +@pytest.mark.parametrize( + "mqtt_config_entry_data", + [ + { + mqtt.CONF_BROKER: "mock-broker", + mqtt.CONF_BIRTH_MESSAGE: {}, + mqtt.CONF_DISCOVERY: False, + } + ], +) +@patch("homeassistant.components.mqtt.client.MAX_SUBSCRIBES_PER_CALL", 2) +@patch("homeassistant.components.mqtt.client.MAX_UNSUBSCRIBES_PER_CALL", 2) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +async def test_mqtt_subscribes_and_unsubscribes_in_chunks( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry: MqttMockHAClientGenerator, + record_calls: MessageCallbackType, +) -> None: + """Test chunked client subscriptions.""" + mqtt_mock = await mqtt_mock_entry() + # Fake that the client is connected + mqtt_mock().connected = True + + mqtt_client_mock.subscribe.reset_mock() + unsub_tasks: list[CALLBACK_TYPE] = [] + unsub_tasks.append(await mqtt.async_subscribe(hass, "topic/test1", record_calls)) + unsub_tasks.append(await mqtt.async_subscribe(hass, "home/sensor1", record_calls)) + unsub_tasks.append(await mqtt.async_subscribe(hass, "topic/test2", record_calls)) + unsub_tasks.append(await mqtt.async_subscribe(hass, "home/sensor2", record_calls)) + await hass.async_block_till_done() + # Make sure the debouncer finishes + await asyncio.sleep(0.2) + + assert mqtt_client_mock.subscribe.call_count == 2 + # Assert we have a 2 subscription calls with both 2 subscriptions + assert len(mqtt_client_mock.subscribe.mock_calls[0][1][0]) == 2 + assert len(mqtt_client_mock.subscribe.mock_calls[1][1][0]) == 2 + + # Unsubscribe all topics + for task in unsub_tasks: + task() + await hass.async_block_till_done() + # Make sure the debouncer finishes + await asyncio.sleep(0.2) + + assert mqtt_client_mock.unsubscribe.call_count == 2 + # Assert we have a 2 unsubscribe calls with both 2 topic + assert len(mqtt_client_mock.unsubscribe.mock_calls[0][1][0]) == 2 + assert len(mqtt_client_mock.unsubscribe.mock_calls[1][1][0]) == 2 + + async def test_default_entry_setting_are_applied( hass: HomeAssistant, device_registry: dr.DeviceRegistry,