Merge pending MQTT subscribes to a single call to the paho client (#92172)
* Merge mqtt subscribes in one call * Cleanup * cleanup, log outside of lock * Remove function wrapper * Add test that we bundle subscriptions
This commit is contained in:
parent
bafb01246a
commit
689c6fbef7
4 changed files with 82 additions and 58 deletions
|
@ -644,7 +644,6 @@ class MQTT:
|
||||||
|
|
||||||
async def _async_perform_subscriptions(self) -> None:
|
async def _async_perform_subscriptions(self) -> None:
|
||||||
"""Perform MQTT client subscriptions."""
|
"""Perform MQTT client subscriptions."""
|
||||||
subscriptions: dict[str, int]
|
|
||||||
# Section 3.3.1.3 in the specification:
|
# Section 3.3.1.3 in the specification:
|
||||||
# http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
|
# http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
|
||||||
# When sending a PUBLISH Packet to a Client the Server MUST
|
# When sending a PUBLISH Packet to a Client the Server MUST
|
||||||
|
@ -657,36 +656,26 @@ class MQTT:
|
||||||
# Since we do not know if a published value is retained we need to
|
# Since we do not know if a published value is retained we need to
|
||||||
# (re)subscribe, to ensure retained messages are replayed
|
# (re)subscribe, to ensure retained messages are replayed
|
||||||
|
|
||||||
def _process_client_subscriptions() -> list[tuple[int, int]]:
|
if not self._pending_subscriptions:
|
||||||
"""Initiate all subscriptions on the MQTT client and return the results."""
|
return
|
||||||
subscribe_result_list = []
|
|
||||||
for topic, qos in subscriptions.items():
|
|
||||||
result, mid = self._mqttc.subscribe(topic, qos)
|
|
||||||
subscribe_result_list.append((result, mid))
|
|
||||||
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
|
|
||||||
return subscribe_result_list
|
|
||||||
|
|
||||||
subscriptions = self._pending_subscriptions
|
subscriptions: dict[str, int] = self._pending_subscriptions
|
||||||
self._pending_subscriptions = {}
|
self._pending_subscriptions = {}
|
||||||
|
|
||||||
async with self._paho_lock:
|
async with self._paho_lock:
|
||||||
results = await self.hass.async_add_executor_job(
|
subscription_list = list(subscriptions.items())
|
||||||
_process_client_subscriptions
|
result, mid = await self.hass.async_add_executor_job(
|
||||||
|
self._mqttc.subscribe, subscription_list
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for topic, qos in subscriptions.items():
|
||||||
|
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
|
||||||
self._last_subscribe = time.time()
|
self._last_subscribe = time.time()
|
||||||
|
|
||||||
tasks: list[Coroutine[Any, Any, None]] = []
|
if result == 0:
|
||||||
errors: list[int] = []
|
await self._wait_for_mid(mid)
|
||||||
for result, mid in results:
|
else:
|
||||||
if result == 0:
|
_raise_on_error(result)
|
||||||
tasks.append(self._wait_for_mid(mid))
|
|
||||||
else:
|
|
||||||
errors.append(result)
|
|
||||||
|
|
||||||
if tasks:
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
if errors:
|
|
||||||
_raise_on_errors(errors)
|
|
||||||
|
|
||||||
def _mqtt_on_connect(
|
def _mqtt_on_connect(
|
||||||
self,
|
self,
|
||||||
|
@ -904,22 +893,13 @@ class MQTT:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_errors(result_codes: Iterable[int]) -> None:
|
def _raise_on_error(result_code: int) -> None:
|
||||||
"""Raise error if error result."""
|
"""Raise error if error result."""
|
||||||
# pylint: disable-next=import-outside-toplevel
|
# pylint: disable-next=import-outside-toplevel
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
if messages := [
|
if result_code and (message := mqtt.error_string(result_code)):
|
||||||
mqtt.error_string(result_code)
|
raise HomeAssistantError(f"Error talking to MQTT: {message}")
|
||||||
for result_code in result_codes
|
|
||||||
if result_code != 0
|
|
||||||
]:
|
|
||||||
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
|
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result_code: int) -> None:
|
|
||||||
"""Raise error if error result."""
|
|
||||||
_raise_on_errors((result_code,))
|
|
||||||
|
|
||||||
|
|
||||||
def _matcher_for_topic(subscription: str) -> Any:
|
def _matcher_for_topic(subscription: str) -> Any:
|
||||||
|
|
|
@ -73,6 +73,15 @@ _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType | None
|
||||||
MQTT_YAML_SCHEMA = vol.Schema({mqtt.DOMAIN: PLATFORM_CONFIG_SCHEMA_BASE})
|
MQTT_YAML_SCHEMA = vol.Schema({mqtt.DOMAIN: PLATFORM_CONFIG_SCHEMA_BASE})
|
||||||
|
|
||||||
|
|
||||||
|
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
|
||||||
|
"""Test of a call."""
|
||||||
|
all_calls = []
|
||||||
|
for calls in mqtt_client_mock.subscribe.mock_calls:
|
||||||
|
for call in calls[1]:
|
||||||
|
all_calls.extend(call)
|
||||||
|
return all_calls
|
||||||
|
|
||||||
|
|
||||||
def help_test_validate_platform_config(
|
def help_test_validate_platform_config(
|
||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> ConfigType | None:
|
) -> ConfigType | None:
|
||||||
|
|
|
@ -28,7 +28,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
|
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .test_common import help_test_unload_config_entry
|
from .test_common import help_all_subscribe_calls, help_test_unload_config_entry
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockConfigEntry,
|
MockConfigEntry,
|
||||||
|
@ -1396,7 +1396,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
assert not mqtt_client_mock.unsubscribe.called
|
assert not mqtt_client_mock.unsubscribe.called
|
||||||
|
|
||||||
class TestFlow(config_entries.ConfigFlow):
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
@ -1407,7 +1407,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
|
||||||
return self.async_abort(reason="already_configured")
|
return self.async_abort(reason="already_configured")
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
assert not mqtt_client_mock.unsubscribe.called
|
assert not mqtt_client_mock.unsubscribe.called
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "")
|
||||||
|
@ -1443,7 +1443,7 @@ async def test_mqtt_discovery_unsubscribe_once(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0)
|
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
assert not mqtt_client_mock.unsubscribe.called
|
assert not mqtt_client_mock.unsubscribe.called
|
||||||
|
|
||||||
class TestFlow(config_entries.ConfigFlow):
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
|
|
@ -36,7 +36,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from .test_common import help_test_validate_platform_config
|
from .test_common import help_all_subscribe_calls, help_test_validate_platform_config
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockConfigEntry,
|
MockConfigEntry,
|
||||||
|
@ -1342,16 +1342,16 @@ async def test_unsubscribe_race(
|
||||||
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
|
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
|
||||||
# when both subscriptions were combined [subscribe]
|
# when both subscriptions were combined [subscribe]
|
||||||
expected_calls_1 = [
|
expected_calls_1 = [
|
||||||
call.subscribe("test/state", 0),
|
call.subscribe([("test/state", 0)]),
|
||||||
call.unsubscribe("test/state"),
|
call.unsubscribe("test/state"),
|
||||||
call.subscribe("test/state", 0),
|
call.subscribe([("test/state", 0)]),
|
||||||
]
|
]
|
||||||
expected_calls_2 = [
|
expected_calls_2 = [
|
||||||
call.subscribe("test/state", 0),
|
call.subscribe([("test/state", 0)]),
|
||||||
call.subscribe("test/state", 0),
|
call.subscribe([("test/state", 0)]),
|
||||||
]
|
]
|
||||||
expected_calls_3 = [
|
expected_calls_3 = [
|
||||||
call.subscribe("test/state", 0),
|
call.subscribe([("test/state", 0)]),
|
||||||
]
|
]
|
||||||
assert mqtt_client_mock.mock_calls in (
|
assert mqtt_client_mock.mock_calls in (
|
||||||
expected_calls_1,
|
expected_calls_1,
|
||||||
|
@ -1418,7 +1418,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||||
|
|
||||||
# the subscribtion with the highest QoS should survive
|
# the subscribtion with the highest QoS should survive
|
||||||
expected = [
|
expected = [
|
||||||
call("test/state", 2),
|
call([("test/state", 2)]),
|
||||||
]
|
]
|
||||||
assert mqtt_client_mock.subscribe.mock_calls == expected
|
assert mqtt_client_mock.subscribe.mock_calls == expected
|
||||||
|
|
||||||
|
@ -1432,7 +1432,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
|
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
expected.append(call("test/state", 1))
|
expected.append(call([("test/state", 1)]))
|
||||||
assert mqtt_client_mock.subscribe.mock_calls == expected
|
assert mqtt_client_mock.subscribe.mock_calls == expected
|
||||||
|
|
||||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
|
async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown
|
||||||
|
@ -1463,9 +1463,7 @@ async def test_subscribed_at_highest_qos(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert mqtt_client_mock.subscribe.mock_calls == [
|
assert ("test/state", 0) in help_all_subscribe_calls(mqtt_client_mock)
|
||||||
call("test/state", 0),
|
|
||||||
]
|
|
||||||
mqtt_client_mock.reset_mock()
|
mqtt_client_mock.reset_mock()
|
||||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
@ -1477,9 +1475,7 @@ async def test_subscribed_at_highest_qos(
|
||||||
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
# the subscribtion with the highest QoS should survive
|
# the subscribtion with the highest QoS should survive
|
||||||
assert mqtt_client_mock.subscribe.mock_calls == [
|
assert help_all_subscribe_calls(mqtt_client_mock) == [("test/state", 2)]
|
||||||
call("test/state", 2),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_reload_entry_with_restored_subscriptions(
|
async def test_reload_entry_with_restored_subscriptions(
|
||||||
|
@ -2224,10 +2220,49 @@ async def test_mqtt_subscribes_topics_on_connect(
|
||||||
|
|
||||||
assert mqtt_client_mock.disconnect.call_count == 0
|
assert mqtt_client_mock.disconnect.call_count == 0
|
||||||
|
|
||||||
assert mqtt_client_mock.subscribe.call_count == 3
|
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
|
||||||
mqtt_client_mock.subscribe.assert_any_call("topic/test", 0)
|
assert len(subscribe_calls) == 3
|
||||||
mqtt_client_mock.subscribe.assert_any_call("home/sensor", 2)
|
assert ("topic/test", 0) in subscribe_calls
|
||||||
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
|
assert ("home/sensor", 2) in subscribe_calls
|
||||||
|
assert ("still/pending", 1) in subscribe_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mqtt_config_entry_data",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
mqtt.CONF_BROKER: "mock-broker",
|
||||||
|
mqtt.CONF_BIRTH_MESSAGE: {},
|
||||||
|
mqtt.CONF_DISCOVERY: False,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
|
async def test_mqtt_subscribes_in_single_call(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_client_mock: MqttMockPahoClient,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
record_calls: MessageCallbackType,
|
||||||
|
) -> None:
|
||||||
|
"""Test bundled client subscription to topic."""
|
||||||
|
mqtt_mock = await mqtt_mock_entry()
|
||||||
|
# Fake that the client is connected
|
||||||
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
|
mqtt_client_mock.subscribe.reset_mock()
|
||||||
|
await mqtt.async_subscribe(hass, "topic/test", record_calls)
|
||||||
|
await mqtt.async_subscribe(hass, "home/sensor", record_calls)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
# Make sure the debouncer finishes
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
assert mqtt_client_mock.subscribe.call_count == 1
|
||||||
|
# Assert we have a single subscription call with both subscriptions
|
||||||
|
assert mqtt_client_mock.subscribe.mock_calls[0][1][0] in [
|
||||||
|
[("topic/test", 0), ("home/sensor", 0)],
|
||||||
|
[("home/sensor", 0), ("topic/test", 0)],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def test_default_entry_setting_are_applied(
|
async def test_default_entry_setting_are_applied(
|
||||||
|
|
Loading…
Add table
Reference in a new issue