Subscribe per component for MQTT discovery (#119974)
* Subscribe per component for MQTT discovery * Use single assignment * Handle wildcard subscriptions first * Split subsRecription handling, update helper * Fix help_all_subscribe_calls * Fix import * Fix test * Update import order * Undo move self._last_subscribe * Recover removed test * Revert not needed changes to binary_sensor platform tests * Revert line removal * Rework interation of discovery topics * Reduce * Add comment * Move comment * Chain subscriptions
This commit is contained in:
parent
a1e3e7f24f
commit
b74aced6f3
5 changed files with 50 additions and 26 deletions
|
@ -111,6 +111,7 @@ UNSUBSCRIBE_COOLDOWN = 0.1
|
||||||
TIMEOUT_ACK = 10
|
TIMEOUT_ACK = 10
|
||||||
RECONNECT_INTERVAL_SECONDS = 10
|
RECONNECT_INTERVAL_SECONDS = 10
|
||||||
|
|
||||||
|
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
|
||||||
MAX_SUBSCRIBES_PER_CALL = 500
|
MAX_SUBSCRIBES_PER_CALL = 500
|
||||||
MAX_UNSUBSCRIBES_PER_CALL = 500
|
MAX_UNSUBSCRIBES_PER_CALL = 500
|
||||||
|
|
||||||
|
@ -893,14 +894,27 @@ class MQTT:
|
||||||
if not self._pending_subscriptions:
|
if not self._pending_subscriptions:
|
||||||
return
|
return
|
||||||
|
|
||||||
subscriptions: dict[str, int] = self._pending_subscriptions
|
# Split out the wildcard subscriptions, we subscribe to them one by one
|
||||||
|
pending_subscriptions: dict[str, int] = self._pending_subscriptions
|
||||||
|
pending_wildcard_subscriptions = {
|
||||||
|
subscription.topic: pending_subscriptions.pop(subscription.topic)
|
||||||
|
for subscription in self._wildcard_subscriptions
|
||||||
|
if subscription.topic in pending_subscriptions
|
||||||
|
}
|
||||||
|
|
||||||
self._pending_subscriptions = {}
|
self._pending_subscriptions = {}
|
||||||
|
|
||||||
subscription_list = list(subscriptions.items())
|
|
||||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||||
|
|
||||||
for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL):
|
for chunk in chain(
|
||||||
|
chunked_or_all(
|
||||||
|
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
|
||||||
|
),
|
||||||
|
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
|
||||||
|
):
|
||||||
chunk_list = list(chunk)
|
chunk_list = list(chunk)
|
||||||
|
if not chunk_list:
|
||||||
|
continue
|
||||||
|
|
||||||
result, mid = self._mqttc.subscribe(chunk_list)
|
result, mid = self._mqttc.subscribe(chunk_list)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import functools
|
import functools
|
||||||
|
from itertools import chain
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
@ -238,10 +239,6 @@ async def async_start( # noqa: C901
|
||||||
|
|
||||||
component, node_id, object_id = match.groups()
|
component, node_id, object_id = match.groups()
|
||||||
|
|
||||||
if component not in SUPPORTED_COMPONENTS:
|
|
||||||
_LOGGER.warning("Integration %s is not supported", component)
|
|
||||||
return
|
|
||||||
|
|
||||||
if payload:
|
if payload:
|
||||||
try:
|
try:
|
||||||
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
|
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
|
||||||
|
@ -351,9 +348,15 @@ async def async_start( # noqa: C901
|
||||||
0,
|
0,
|
||||||
job_type=HassJobType.Callback,
|
job_type=HassJobType.Callback,
|
||||||
)
|
)
|
||||||
for topic in (
|
for topic in chain(
|
||||||
f"{discovery_topic}/+/+/config",
|
(
|
||||||
f"{discovery_topic}/+/+/+/config",
|
f"{discovery_topic}/{component}/+/config"
|
||||||
|
for component in SUPPORTED_COMPONENTS
|
||||||
|
),
|
||||||
|
(
|
||||||
|
f"{discovery_topic}/{component}/+/+/config"
|
||||||
|
for component in SUPPORTED_COMPONENTS
|
||||||
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS
|
from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS
|
||||||
|
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
|
||||||
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
|
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
|
||||||
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -1614,8 +1615,9 @@ async def test_subscription_done_when_birth_message_is_sent(
|
||||||
"""Test sending birth message until initial subscription has been completed."""
|
"""Test sending birth message until initial subscription has been completed."""
|
||||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||||
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
|
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
|
||||||
assert ("homeassistant/+/+/config", 0) in subscribe_calls
|
for component in SUPPORTED_COMPONENTS:
|
||||||
assert ("homeassistant/+/+/+/config", 0) in subscribe_calls
|
assert (f"homeassistant/{component}/+/config", 0) in subscribe_calls
|
||||||
|
assert (f"homeassistant/{component}/+/+/config", 0) in subscribe_calls
|
||||||
mqtt_client_mock.publish.assert_called_with(
|
mqtt_client_mock.publish.assert_called_with(
|
||||||
"homeassistant/status", "online", 0, False
|
"homeassistant/status", "online", 0, False
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,10 @@ import yaml
|
||||||
from homeassistant import config as module_hass_config
|
from homeassistant import config as module_hass_config
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.components.mqtt import debug_info
|
from homeassistant.components.mqtt import debug_info
|
||||||
from homeassistant.components.mqtt.const import MQTT_CONNECTION_STATE
|
from homeassistant.components.mqtt.const import (
|
||||||
|
MQTT_CONNECTION_STATE,
|
||||||
|
SUPPORTED_COMPONENTS,
|
||||||
|
)
|
||||||
from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED
|
from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED
|
||||||
from homeassistant.components.mqtt.models import PublishPayloadType
|
from homeassistant.components.mqtt.models import PublishPayloadType
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
|
@ -75,9 +78,12 @@ type _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType |
|
||||||
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
|
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
|
||||||
"""Test of a call."""
|
"""Test of a call."""
|
||||||
all_calls = []
|
all_calls = []
|
||||||
for calls in mqtt_client_mock.subscribe.mock_calls:
|
for call_l1 in mqtt_client_mock.subscribe.mock_calls:
|
||||||
for call in calls[1]:
|
if isinstance(call_l1[1][0], list):
|
||||||
all_calls.extend(call)
|
for call_l2 in call_l1[1]:
|
||||||
|
all_calls.extend(call_l2)
|
||||||
|
else:
|
||||||
|
all_calls.append(call_l1[1])
|
||||||
return all_calls
|
return all_calls
|
||||||
|
|
||||||
|
|
||||||
|
@ -1178,7 +1184,10 @@ async def help_test_entity_id_update_subscriptions(
|
||||||
|
|
||||||
state = hass.states.get(f"{domain}.test")
|
state = hass.states.get(f"{domain}.test")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert mqtt_mock.async_subscribe.call_count == len(topics) + 2 + DISCOVERY_COUNT
|
assert (
|
||||||
|
mqtt_mock.async_subscribe.call_count
|
||||||
|
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
|
||||||
|
)
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
mqtt_mock.async_subscribe.assert_any_call(
|
mqtt_mock.async_subscribe.assert_any_call(
|
||||||
topic, ANY, ANY, ANY, HassJobType.Callback
|
topic, ANY, ANY, ANY, HassJobType.Callback
|
||||||
|
|
|
@ -15,6 +15,7 @@ from homeassistant.components.mqtt.abbreviations import (
|
||||||
ABBREVIATIONS,
|
ABBREVIATIONS,
|
||||||
DEVICE_ABBREVIATIONS,
|
DEVICE_ABBREVIATIONS,
|
||||||
)
|
)
|
||||||
|
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
|
||||||
from homeassistant.components.mqtt.discovery import (
|
from homeassistant.components.mqtt.discovery import (
|
||||||
MQTT_DISCOVERY_DONE,
|
MQTT_DISCOVERY_DONE,
|
||||||
MQTT_DISCOVERY_NEW,
|
MQTT_DISCOVERY_NEW,
|
||||||
|
@ -73,13 +74,10 @@ async def test_subscribing_config_topic(
|
||||||
discovery_topic = "homeassistant"
|
discovery_topic = "homeassistant"
|
||||||
await async_start(hass, discovery_topic, entry)
|
await async_start(hass, discovery_topic, entry)
|
||||||
|
|
||||||
call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1]
|
topics = [call[1][0] for call in mqtt_mock.async_subscribe.mock_calls]
|
||||||
assert call_args1[2] == 0
|
for component in SUPPORTED_COMPONENTS:
|
||||||
call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1]
|
assert f"{discovery_topic}/{component}/+/config" in topics
|
||||||
assert call_args2[2] == 0
|
assert f"{discovery_topic}/{component}/+/+/config" in topics
|
||||||
topics = [call_args1[0], call_args2[0]]
|
|
||||||
assert discovery_topic + "/+/+/config" in topics
|
|
||||||
assert discovery_topic + "/+/+/+/config" in topics
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -198,8 +196,6 @@ async def test_only_valid_components(
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert f"Integration {invalid_component} is not supported" in caplog.text
|
|
||||||
|
|
||||||
assert not mock_dispatcher_send.called
|
assert not mock_dispatcher_send.called
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue