Speed up subscribing to mqtt topics on connect (#73685)
* Speed up subscribing to mqtt topics * update tests * Remove extra function wrapper * Recover debug logging for subscriptions * Small changes and test * Update homeassistant/components/mqtt/client.py * Update client.py Co-authored-by: jbouwh <jan@jbsoft.nl> Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
54591b8ca1
commit
19b2b33037
2 changed files with 74 additions and 22 deletions
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Iterable
|
||||||
from functools import lru_cache, partial, wraps
|
from functools import lru_cache, partial, wraps
|
||||||
import inspect
|
import inspect
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
|
@ -430,7 +430,7 @@ class MQTT:
|
||||||
# Only subscribe if currently connected.
|
# Only subscribe if currently connected.
|
||||||
if self.connected:
|
if self.connected:
|
||||||
self._last_subscribe = time.time()
|
self._last_subscribe = time.time()
|
||||||
await self._async_perform_subscription(topic, qos)
|
await self._async_perform_subscriptions(((topic, qos),))
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove() -> None:
|
def async_remove() -> None:
|
||||||
|
@ -464,16 +464,37 @@ class MQTT:
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
await self._wait_for_mid(mid)
|
await self._wait_for_mid(mid)
|
||||||
|
|
||||||
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
async def _async_perform_subscriptions(
|
||||||
"""Perform a paho-mqtt subscription."""
|
self, subscriptions: Iterable[tuple[str, int]]
|
||||||
|
) -> None:
|
||||||
|
"""Perform MQTT client subscriptions."""
|
||||||
|
|
||||||
|
def _process_client_subscriptions() -> list[tuple[int, int]]:
|
||||||
|
"""Initiate all subscriptions on the MQTT client and return the results."""
|
||||||
|
subscribe_result_list = []
|
||||||
|
for topic, qos in subscriptions:
|
||||||
|
result, mid = self._mqttc.subscribe(topic, qos)
|
||||||
|
subscribe_result_list.append((result, mid))
|
||||||
|
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
||||||
|
return subscribe_result_list
|
||||||
|
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
result: int | None = None
|
results = await self.hass.async_add_executor_job(
|
||||||
result, mid = await self.hass.async_add_executor_job(
|
_process_client_subscriptions
|
||||||
self._mqttc.subscribe, topic, qos
|
|
||||||
)
|
)
|
||||||
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
|
|
||||||
_raise_on_error(result)
|
tasks = []
|
||||||
await self._wait_for_mid(mid)
|
errors = []
|
||||||
|
for result, mid in results:
|
||||||
|
if result == 0:
|
||||||
|
tasks.append(self._wait_for_mid(mid))
|
||||||
|
else:
|
||||||
|
errors.append(result)
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
if errors:
|
||||||
|
_raise_on_errors(errors)
|
||||||
|
|
||||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
|
||||||
"""On connect callback.
|
"""On connect callback.
|
||||||
|
@ -502,10 +523,16 @@ class MQTT:
|
||||||
|
|
||||||
# Group subscriptions to only re-subscribe once for each topic.
|
# Group subscriptions to only re-subscribe once for each topic.
|
||||||
keyfunc = attrgetter("topic")
|
keyfunc = attrgetter("topic")
|
||||||
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc):
|
self.hass.add_job(
|
||||||
# Re-subscribe with the highest requested qos
|
self._async_perform_subscriptions,
|
||||||
max_qos = max(subscription.qos for subscription in subs)
|
[
|
||||||
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
# Re-subscribe with the highest requested qos
|
||||||
|
(topic, max(subscription.qos for subscription in subs))
|
||||||
|
for topic, subs in groupby(
|
||||||
|
sorted(self.subscriptions, key=keyfunc), keyfunc
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
CONF_BIRTH_MESSAGE in self.conf
|
CONF_BIRTH_MESSAGE in self.conf
|
||||||
|
@ -638,15 +665,22 @@ class MQTT:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result_code: int | None) -> None:
|
def _raise_on_errors(result_codes: Iterable[int | None]) -> None:
|
||||||
"""Raise error if error result."""
|
"""Raise error if error result."""
|
||||||
# pylint: disable-next=import-outside-toplevel
|
# pylint: disable-next=import-outside-toplevel
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
if result_code is not None and result_code != 0:
|
if messages := [
|
||||||
raise HomeAssistantError(
|
mqtt.error_string(result_code)
|
||||||
f"Error talking to MQTT: {mqtt.error_string(result_code)}"
|
for result_code in result_codes
|
||||||
)
|
if result_code != 0
|
||||||
|
]:
|
||||||
|
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_on_error(result_code: int | None) -> None:
|
||||||
|
"""Raise error if error result."""
|
||||||
|
_raise_on_errors((result_code,))
|
||||||
|
|
||||||
|
|
||||||
def _matcher_for_topic(subscription: str) -> Any:
|
def _matcher_for_topic(subscription: str) -> Any:
|
||||||
|
|
|
@ -1312,6 +1312,20 @@ async def test_publish_error(hass, caplog):
|
||||||
assert "Failed to connect to MQTT server: Out of memory." in caplog.text
|
assert "Failed to connect to MQTT server: Out of memory." in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_error(
|
||||||
|
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
|
||||||
|
):
|
||||||
|
"""Test publish error."""
|
||||||
|
await mqtt_mock_entry_no_yaml_config()
|
||||||
|
mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
with pytest.raises(HomeAssistantError):
|
||||||
|
# simulate client is not connected error before subscribing
|
||||||
|
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
|
||||||
|
await mqtt.async_subscribe(hass, "some-topic", lambda *args: 0)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
async def test_handle_message_callback(
|
async def test_handle_message_callback(
|
||||||
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
|
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
|
||||||
):
|
):
|
||||||
|
@ -1424,6 +1438,7 @@ async def test_setup_mqtt_client_protocol(hass):
|
||||||
|
|
||||||
|
|
||||||
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2)
|
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2)
|
||||||
|
@patch("homeassistant.components.mqtt.PLATFORMS", [])
|
||||||
async def test_handle_mqtt_timeout_on_callback(hass, caplog):
|
async def test_handle_mqtt_timeout_on_callback(hass, caplog):
|
||||||
"""Test publish without receiving an ACK callback."""
|
"""Test publish without receiving an ACK callback."""
|
||||||
mid = 0
|
mid = 0
|
||||||
|
@ -1764,9 +1779,12 @@ async def test_mqtt_subscribes_topics_on_connect(
|
||||||
|
|
||||||
assert mqtt_client_mock.disconnect.call_count == 0
|
assert mqtt_client_mock.disconnect.call_count == 0
|
||||||
|
|
||||||
expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1}
|
assert len(hass.add_job.mock_calls) == 1
|
||||||
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
|
assert set(hass.add_job.mock_calls[0][1][1]) == {
|
||||||
assert calls == expected
|
("home/sensor", 2),
|
||||||
|
("still/pending", 1),
|
||||||
|
("topic/test", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_entry_with_config_override(
|
async def test_setup_entry_with_config_override(
|
||||||
|
|
Loading…
Add table
Reference in a new issue