Speed up subscribing to mqtt topics on connect (#73685)
* Speed up subscribing to mqtt topics * update tests * Remove extra function wrapper * Recover debug logging for subscriptions * Small changes and test * Update homeassistant/components/mqtt/client.py * Update client.py Co-authored-by: jbouwh <jan@jbsoft.nl> Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
54591b8ca1
commit
19b2b33037
2 changed files with 74 additions and 22 deletions
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import lru_cache, partial, wraps
|
||||
import inspect
|
||||
from itertools import groupby
|
||||
|
@ -430,7 +430,7 @@ class MQTT:
|
|||
# Only subscribe if currently connected.
|
||||
if self.connected:
|
||||
self._last_subscribe = time.time()
|
||||
await self._async_perform_subscription(topic, qos)
|
||||
await self._async_perform_subscriptions(((topic, qos),))
|
||||
|
||||
@callback
|
||||
def async_remove() -> None:
|
||||
|
@ -464,16 +464,37 @@ class MQTT:
|
|||
_raise_on_error(result)
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
||||
"""Perform a paho-mqtt subscription."""
|
||||
async def _async_perform_subscriptions(
|
||||
self, subscriptions: Iterable[tuple[str, int]]
|
||||
) -> None:
|
||||
"""Perform MQTT client subscriptions."""
|
||||
|
||||
def _process_client_subscriptions() -> list[tuple[int, int]]:
|
||||
"""Initiate all subscriptions on the MQTT client and return the results."""
|
||||
subscribe_result_list = []
|
||||
for topic, qos in subscriptions:
|
||||
result, mid = self._mqttc.subscribe(topic, qos)
|
||||
subscribe_result_list.append((result, mid))
|
||||
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
||||
return subscribe_result_list
|
||||
|
||||
async with self._paho_lock:
|
||||
result: int | None = None
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.subscribe, topic, qos
|
||||
results = await self.hass.async_add_executor_job(
|
||||
_process_client_subscriptions
|
||||
)
|
||||
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
||||
_raise_on_error(result)
|
||||
await self._wait_for_mid(mid)
|
||||
|
||||
tasks = []
|
||||
errors = []
|
||||
for result, mid in results:
|
||||
if result == 0:
|
||||
tasks.append(self._wait_for_mid(mid))
|
||||
else:
|
||||
errors.append(result)
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
if errors:
|
||||
_raise_on_errors(errors)
|
||||
|
||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
||||
"""On connect callback.
|
||||
|
@ -502,10 +523,16 @@ class MQTT:
|
|||
|
||||
# Group subscriptions to only re-subscribe once for each topic.
|
||||
keyfunc = attrgetter("topic")
|
||||
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc):
|
||||
# Re-subscribe with the highest requested qos
|
||||
max_qos = max(subscription.qos for subscription in subs)
|
||||
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
||||
self.hass.add_job(
|
||||
self._async_perform_subscriptions,
|
||||
[
|
||||
# Re-subscribe with the highest requested qos
|
||||
(topic, max(subscription.qos for subscription in subs))
|
||||
for topic, subs in groupby(
|
||||
sorted(self.subscriptions, key=keyfunc), keyfunc
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if (
|
||||
CONF_BIRTH_MESSAGE in self.conf
|
||||
|
@ -638,15 +665,22 @@ class MQTT:
|
|||
)
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int | None) -> None:
|
||||
def _raise_on_errors(result_codes: Iterable[int | None]) -> None:
|
||||
"""Raise error if error result."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
if result_code is not None and result_code != 0:
|
||||
raise HomeAssistantError(
|
||||
f"Error talking to MQTT: {mqtt.error_string(result_code)}"
|
||||
)
|
||||
if messages := [
|
||||
mqtt.error_string(result_code)
|
||||
for result_code in result_codes
|
||||
if result_code != 0
|
||||
]:
|
||||
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int | None) -> None:
|
||||
"""Raise error if error result."""
|
||||
_raise_on_errors((result_code,))
|
||||
|
||||
|
||||
def _matcher_for_topic(subscription: str) -> Any:
|
||||
|
|
|
@ -1312,6 +1312,20 @@ async def test_publish_error(hass, caplog):
|
|||
assert "Failed to connect to MQTT server: Out of memory." in caplog.text
|
||||
|
||||
|
||||
async def test_subscribe_error(
|
||||
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
|
||||
):
|
||||
"""Test publish error."""
|
||||
await mqtt_mock_entry_no_yaml_config()
|
||||
mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0)
|
||||
await hass.async_block_till_done()
|
||||
with pytest.raises(HomeAssistantError):
|
||||
# simulate client is not connected error before subscribing
|
||||
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
|
||||
await mqtt.async_subscribe(hass, "some-topic", lambda *args: 0)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_handle_message_callback(
|
||||
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
|
||||
):
|
||||
|
@ -1424,6 +1438,7 @@ async def test_setup_mqtt_client_protocol(hass):
|
|||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2)
|
||||
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||
async def test_handle_mqtt_timeout_on_callback(hass, caplog):
|
||||
"""Test publish without receiving an ACK callback."""
|
||||
mid = 0
|
||||
|
@ -1764,9 +1779,12 @@ async def test_mqtt_subscribes_topics_on_connect(
|
|||
|
||||
assert mqtt_client_mock.disconnect.call_count == 0
|
||||
|
||||
expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1}
|
||||
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
|
||||
assert calls == expected
|
||||
assert len(hass.add_job.mock_calls) == 1
|
||||
assert set(hass.add_job.mock_calls[0][1][1]) == {
|
||||
("home/sensor", 2),
|
||||
("still/pending", 1),
|
||||
("topic/test", 0),
|
||||
}
|
||||
|
||||
|
||||
async def test_setup_entry_with_config_override(
|
||||
|
|
Loading…
Add table
Reference in a new issue