Subscribe per component for MQTT discovery (#119974)

* Subscribe per component for MQTT discovery

* Use single assignment

* Handle wildcard subscriptions first

* Split subsRecription handling, update helper

* Fix help_all_subscribe_calls

* Fix import

* Fix test

* Update import order

* Undo move self._last_subscribe

* Recover removed test

* Revert not needed changes to binary_sensor platform tests

* Revert line removal

* Rework interation of discovery topics

* Reduce

* Add comment

* Move comment

* Chain subscriptions
This commit is contained in:
Jan Bouwhuis 2024-08-20 17:02:48 +02:00 committed by GitHub
parent a1e3e7f24f
commit b74aced6f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 50 additions and 26 deletions

View file

@ -111,6 +111,7 @@ UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10 TIMEOUT_ACK = 10
RECONNECT_INTERVAL_SECONDS = 10 RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
MAX_SUBSCRIBES_PER_CALL = 500 MAX_SUBSCRIBES_PER_CALL = 500
MAX_UNSUBSCRIBES_PER_CALL = 500 MAX_UNSUBSCRIBES_PER_CALL = 500
@ -893,14 +894,27 @@ class MQTT:
if not self._pending_subscriptions: if not self._pending_subscriptions:
return return
subscriptions: dict[str, int] = self._pending_subscriptions # Split out the wildcard subscriptions, we subscribe to them one by one
pending_subscriptions: dict[str, int] = self._pending_subscriptions
pending_wildcard_subscriptions = {
subscription.topic: pending_subscriptions.pop(subscription.topic)
for subscription in self._wildcard_subscriptions
if subscription.topic in pending_subscriptions
}
self._pending_subscriptions = {} self._pending_subscriptions = {}
subscription_list = list(subscriptions.items())
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL): for chunk in chain(
chunked_or_all(
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
),
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
):
chunk_list = list(chunk) chunk_list = list(chunk)
if not chunk_list:
continue
result, mid = self._mqttc.subscribe(chunk_list) result, mid = self._mqttc.subscribe(chunk_list)

View file

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
import functools import functools
from itertools import chain
import logging import logging
import re import re
import time import time
@ -238,10 +239,6 @@ async def async_start( # noqa: C901
component, node_id, object_id = match.groups() component, node_id, object_id = match.groups()
if component not in SUPPORTED_COMPONENTS:
_LOGGER.warning("Integration %s is not supported", component)
return
if payload: if payload:
try: try:
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload)) discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
@ -351,9 +348,15 @@ async def async_start( # noqa: C901
0, 0,
job_type=HassJobType.Callback, job_type=HassJobType.Callback,
) )
for topic in ( for topic in chain(
f"{discovery_topic}/+/+/config", (
f"{discovery_topic}/+/+/+/config", f"{discovery_topic}/{component}/+/config"
for component in SUPPORTED_COMPONENTS
),
(
f"{discovery_topic}/{component}/+/+/config"
for component in SUPPORTED_COMPONENTS
),
) )
] ]

View file

@ -13,6 +13,7 @@ import pytest
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
@ -1614,8 +1615,9 @@ async def test_subscription_done_when_birth_message_is_sent(
"""Test sending birth message until initial subscription has been completed.""" """Test sending birth message until initial subscription has been completed."""
mqtt_client_mock = setup_with_birth_msg_client_mock mqtt_client_mock = setup_with_birth_msg_client_mock
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock) subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
assert ("homeassistant/+/+/config", 0) in subscribe_calls for component in SUPPORTED_COMPONENTS:
assert ("homeassistant/+/+/+/config", 0) in subscribe_calls assert (f"homeassistant/{component}/+/config", 0) in subscribe_calls
assert (f"homeassistant/{component}/+/+/config", 0) in subscribe_calls
mqtt_client_mock.publish.assert_called_with( mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False "homeassistant/status", "online", 0, False
) )

View file

@ -16,7 +16,10 @@ import yaml
from homeassistant import config as module_hass_config from homeassistant import config as module_hass_config
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt import debug_info from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.const import MQTT_CONNECTION_STATE from homeassistant.components.mqtt.const import (
MQTT_CONNECTION_STATE,
SUPPORTED_COMPONENTS,
)
from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED
from homeassistant.components.mqtt.models import PublishPayloadType from homeassistant.components.mqtt.models import PublishPayloadType
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
@ -75,9 +78,12 @@ type _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType |
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]: def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
"""Test of a call.""" """Test of a call."""
all_calls = [] all_calls = []
for calls in mqtt_client_mock.subscribe.mock_calls: for call_l1 in mqtt_client_mock.subscribe.mock_calls:
for call in calls[1]: if isinstance(call_l1[1][0], list):
all_calls.extend(call) for call_l2 in call_l1[1]:
all_calls.extend(call_l2)
else:
all_calls.append(call_l1[1])
return all_calls return all_calls
@ -1178,7 +1184,10 @@ async def help_test_entity_id_update_subscriptions(
state = hass.states.get(f"{domain}.test") state = hass.states.get(f"{domain}.test")
assert state is not None assert state is not None
assert mqtt_mock.async_subscribe.call_count == len(topics) + 2 + DISCOVERY_COUNT assert (
mqtt_mock.async_subscribe.call_count
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
)
for topic in topics: for topic in topics:
mqtt_mock.async_subscribe.assert_any_call( mqtt_mock.async_subscribe.assert_any_call(
topic, ANY, ANY, ANY, HassJobType.Callback topic, ANY, ANY, ANY, HassJobType.Callback

View file

@ -15,6 +15,7 @@ from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS, ABBREVIATIONS,
DEVICE_ABBREVIATIONS, DEVICE_ABBREVIATIONS,
) )
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.discovery import ( from homeassistant.components.mqtt.discovery import (
MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW, MQTT_DISCOVERY_NEW,
@ -73,13 +74,10 @@ async def test_subscribing_config_topic(
discovery_topic = "homeassistant" discovery_topic = "homeassistant"
await async_start(hass, discovery_topic, entry) await async_start(hass, discovery_topic, entry)
call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1] topics = [call[1][0] for call in mqtt_mock.async_subscribe.mock_calls]
assert call_args1[2] == 0 for component in SUPPORTED_COMPONENTS:
call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1] assert f"{discovery_topic}/{component}/+/config" in topics
assert call_args2[2] == 0 assert f"{discovery_topic}/{component}/+/+/config" in topics
topics = [call_args1[0], call_args2[0]]
assert discovery_topic + "/+/+/config" in topics
assert discovery_topic + "/+/+/+/config" in topics
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -198,8 +196,6 @@ async def test_only_valid_components(
await hass.async_block_till_done() await hass.async_block_till_done()
assert f"Integration {invalid_component} is not supported" in caplog.text
assert not mock_dispatcher_send.called assert not mock_dispatcher_send.called