Merge pending MQTT subscribes to a single call to the paho client (#92172)

* Merge mqtt subscribes in one call

* Cleanup

* cleanup, log outside of lock

* Remove function wrapper

* Add test that we bundle subscriptions
This commit is contained in:
Jan Bouwhuis 2023-05-08 15:37:25 +02:00 committed by GitHub
parent bafb01246a
commit 689c6fbef7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 58 deletions

View file

@ -644,7 +644,6 @@ class MQTT:
async def _async_perform_subscriptions(self) -> None: async def _async_perform_subscriptions(self) -> None:
"""Perform MQTT client subscriptions.""" """Perform MQTT client subscriptions."""
subscriptions: dict[str, int]
# Section 3.3.1.3 in the specification: # Section 3.3.1.3 in the specification:
# http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
# When sending a PUBLISH Packet to a Client the Server MUST # When sending a PUBLISH Packet to a Client the Server MUST
@ -657,36 +656,26 @@ class MQTT:
# Since we do not know if a published value is retained we need to # Since we do not know if a published value is retained we need to
# (re)subscribe, to ensure retained messages are replayed # (re)subscribe, to ensure retained messages are replayed
def _process_client_subscriptions() -> list[tuple[int, int]]: if not self._pending_subscriptions:
"""Initiate all subscriptions on the MQTT client and return the results.""" return
subscribe_result_list = []
for topic, qos in subscriptions.items():
result, mid = self._mqttc.subscribe(topic, qos)
subscribe_result_list.append((result, mid))
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
return subscribe_result_list
subscriptions = self._pending_subscriptions subscriptions: dict[str, int] = self._pending_subscriptions
self._pending_subscriptions = {} self._pending_subscriptions = {}
async with self._paho_lock: async with self._paho_lock:
results = await self.hass.async_add_executor_job( subscription_list = list(subscriptions.items())
_process_client_subscriptions result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, subscription_list
) )
for topic, qos in subscriptions.items():
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
self._last_subscribe = time.time() self._last_subscribe = time.time()
tasks: list[Coroutine[Any, Any, None]] = [] if result == 0:
errors: list[int] = [] await self._wait_for_mid(mid)
for result, mid in results: else:
if result == 0: _raise_on_error(result)
tasks.append(self._wait_for_mid(mid))
else:
errors.append(result)
if tasks:
await asyncio.gather(*tasks)
if errors:
_raise_on_errors(errors)
def _mqtt_on_connect( def _mqtt_on_connect(
self, self,
@ -904,22 +893,13 @@ class MQTT:
) )
def _raise_on_errors(result_codes: Iterable[int]) -> None: def _raise_on_error(result_code: int) -> None:
"""Raise error if error result.""" """Raise error if error result."""
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
if messages := [ if result_code and (message := mqtt.error_string(result_code)):
mqtt.error_string(result_code) raise HomeAssistantError(f"Error talking to MQTT: {message}")
for result_code in result_codes
if result_code != 0
]:
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
def _raise_on_error(result_code: int) -> None:
"""Raise error if error result."""
_raise_on_errors((result_code,))
def _matcher_for_topic(subscription: str) -> Any: def _matcher_for_topic(subscription: str) -> Any:

View file

@ -73,6 +73,15 @@ _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType | None
MQTT_YAML_SCHEMA = vol.Schema({mqtt.DOMAIN: PLATFORM_CONFIG_SCHEMA_BASE}) MQTT_YAML_SCHEMA = vol.Schema({mqtt.DOMAIN: PLATFORM_CONFIG_SCHEMA_BASE})
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
"""Test of a call."""
all_calls = []
for calls in mqtt_client_mock.subscribe.mock_calls:
for call in calls[1]:
all_calls.extend(call)
return all_calls
def help_test_validate_platform_config( def help_test_validate_platform_config(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConfigType | None: ) -> ConfigType | None:

View file

@ -28,7 +28,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .test_common import help_test_unload_config_entry from .test_common import help_all_subscribe_calls, help_test_unload_config_entry
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
@ -1396,7 +1396,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
class TestFlow(config_entries.ConfigFlow): class TestFlow(config_entries.ConfigFlow):
@ -1407,7 +1407,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
@ -1443,7 +1443,7 @@ async def test_mqtt_discovery_unsubscribe_once(
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
class TestFlow(config_entries.ConfigFlow): class TestFlow(config_entries.ConfigFlow):

View file

@ -36,7 +36,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from .test_common import help_test_validate_platform_config from .test_common import help_all_subscribe_calls, help_test_validate_platform_config
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
@ -1342,16 +1342,16 @@ async def test_unsubscribe_race(
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or # We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
# when both subscriptions were combined [subscribe] # when both subscriptions were combined [subscribe]
expected_calls_1 = [ expected_calls_1 = [
call.subscribe("test/state", 0), call.subscribe([("test/state", 0)]),
call.unsubscribe("test/state"), call.unsubscribe("test/state"),
call.subscribe("test/state", 0), call.subscribe([("test/state", 0)]),
] ]
expected_calls_2 = [ expected_calls_2 = [
call.subscribe("test/state", 0), call.subscribe([("test/state", 0)]),
call.subscribe("test/state", 0), call.subscribe([("test/state", 0)]),
] ]
expected_calls_3 = [ expected_calls_3 = [
call.subscribe("test/state", 0), call.subscribe([("test/state", 0)]),
] ]
assert mqtt_client_mock.mock_calls in ( assert mqtt_client_mock.mock_calls in (
expected_calls_1, expected_calls_1,
@ -1418,7 +1418,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
# the subscribtion with the highest QoS should survive # the subscribtion with the highest QoS should survive
expected = [ expected = [
call("test/state", 2), call([("test/state", 2)]),
] ]
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
@ -1432,7 +1432,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
expected.append(call("test/state", 1)) expected.append(call([("test/state", 1)]))
assert mqtt_client_mock.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
@ -1463,9 +1463,7 @@ async def test_subscribed_at_highest_qos(
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client_mock.subscribe.mock_calls == [ assert ("test/state", 0) in help_all_subscribe_calls(mqtt_client_mock)
call("test/state", 0),
]
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
@ -1477,9 +1475,7 @@ async def test_subscribed_at_highest_qos(
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
await hass.async_block_till_done() await hass.async_block_till_done()
# the subscribtion with the highest QoS should survive # the subscribtion with the highest QoS should survive
assert mqtt_client_mock.subscribe.mock_calls == [ assert help_all_subscribe_calls(mqtt_client_mock) == [("test/state", 2)]
call("test/state", 2),
]
async def test_reload_entry_with_restored_subscriptions( async def test_reload_entry_with_restored_subscriptions(
@ -2224,10 +2220,49 @@ async def test_mqtt_subscribes_topics_on_connect(
assert mqtt_client_mock.disconnect.call_count == 0 assert mqtt_client_mock.disconnect.call_count == 0
assert mqtt_client_mock.subscribe.call_count == 3 subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
mqtt_client_mock.subscribe.assert_any_call("topic/test", 0) assert len(subscribe_calls) == 3
mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2) assert ("topic/test", 0) in subscribe_calls
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1) assert ("home/sensor", 2) in subscribe_calls
assert ("still/pending", 1) in subscribe_calls
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[
{
mqtt.CONF_BROKER: "mock-broker",
mqtt.CONF_BIRTH_MESSAGE: {},
mqtt.CONF_DISCOVERY: False,
}
],
)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_subscribes_in_single_call(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
record_calls: MessageCallbackType,
) -> None:
"""Test bundled client subscription to topic."""
mqtt_mock = await mqtt_mock_entry()
# Fake that the client is connected
mqtt_mock().connected = True
mqtt_client_mock.subscribe.reset_mock()
await mqtt.async_subscribe(hass, "topic/test", record_calls)
await mqtt.async_subscribe(hass, "home/sensor", record_calls)
await hass.async_block_till_done()
# Make sure the debouncer finishes
await asyncio.sleep(0.2)
assert mqtt_client_mock.subscribe.call_count == 1
# Assert we have a single subscription call with both subscriptions
assert mqtt_client_mock.subscribe.mock_calls[0][1][0] in [
[("topic/test", 0), ("home/sensor", 0)],
[("home/sensor", 0), ("topic/test", 0)],
]
async def test_default_entry_setting_are_applied( async def test_default_entry_setting_are_applied(