Allow Just-in-Time platform setup for mqtt (#112720)

* Allow Just-in-Time platform setup for mqtt

* Only forward the setup of new platforms

* Fix new  platforms being setup at reload + test

* Revert not related changes

* Remove unused partial

* Address comments, only import plaforms if needed

* Apply suggestions from code review

* Add multipl platform discovery test

* Improve test

* Use a lock per platform
This commit is contained in:
Jan Bouwhuis 2024-03-09 21:55:00 +01:00 committed by GitHub
parent d0d1af8991
commit 3b0ea52167
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 226 additions and 38 deletions

View file

@ -41,7 +41,6 @@ from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration
from homeassistant.util.async_ import create_eager_task
# Loading the config flow file will register the flow
from . import debug_info, discovery
@ -100,9 +99,11 @@ from .models import ( # noqa: F401
)
from .util import ( # noqa: F401
async_create_certificate_temp_files,
async_forward_entry_setup_and_setup_discovery,
async_wait_for_mqtt_client,
get_mqtt_data,
mqtt_config_entry_enabled,
platforms_from_config,
valid_publish_topic,
valid_qos_schema,
valid_subscribe_topic,
@ -411,13 +412,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
translation_placeholders=ex.translation_placeholders,
) from ex
new_config: list[ConfigType] = config_yaml.get(DOMAIN, [])
platforms_used = platforms_from_config(new_config)
new_platforms = platforms_used - mqtt_data.platforms_loaded
await async_forward_entry_setup_and_setup_discovery(
hass, entry, new_platforms
)
# Check the schema before continuing reload
await async_check_config_schema(hass, config_yaml)
# Remove repair issues
_async_remove_mqtt_issues(hass, mqtt_data)
mqtt_data.config = config_yaml.get(DOMAIN, {})
mqtt_data.config = new_config
# Reload the modern yaml platforms
mqtt_platforms = async_get_platforms(hass, DOMAIN)
@ -439,36 +446,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config)
async def async_forward_entry_setup_and_setup_discovery(
config_entry: ConfigEntry,
conf: ConfigType,
) -> None:
"""Forward the config entry setup to the platforms and set up discovery."""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import device_automation, tag
# Forward the entry setup to the MQTT platforms
await asyncio.gather(
*(
create_eager_task(
device_automation.async_setup_entry(hass, config_entry)
),
create_eager_task(tag.async_setup_entry(hass, config_entry)),
create_eager_task(
hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
),
)
platforms_used = platforms_from_config(mqtt_data.config)
await async_forward_entry_setup_and_setup_discovery(hass, entry, platforms_used)
# Setup reload service after all platforms have loaded
await async_setup_reload_service()
# Setup discovery
if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
await discovery.async_start(
hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
)
# Setup discovery
if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
await discovery.async_start(
hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
)
# Setup reload service after all platforms have loaded
await async_setup_reload_service()
await async_forward_entry_setup_and_setup_discovery(entry, conf)
return True
@ -605,9 +591,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await asyncio.gather(
*(
hass.config_entries.async_forward_entry_unload(entry, component)
for component in PLATFORMS
for component in mqtt_data.platforms_loaded
)
)
mqtt_data.platforms_loaded = set()
await asyncio.sleep(0)
# Unsubscribe reload dispatchers
while reload_dispatchers := mqtt_data.reload_dispatchers:

View file

@ -40,7 +40,7 @@ from .const import (
DOMAIN,
)
from .models import MqttOriginInfo, ReceiveMessage
from .util import get_mqtt_data
from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -81,6 +81,7 @@ SUPPORTED_COMPONENTS = {
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component"
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
TOPIC_BASE = "~"
@ -141,7 +142,33 @@ async def async_start( # noqa: C901
) -> None:
"""Start MQTT Discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {}
platform_setup_lock: dict[str, asyncio.Lock] = {}
async def _async_jit_component_setup(
discovery_payload: MQTTDiscoveryPayload,
) -> None:
"""Perform just in time components set up."""
discovery_hash = discovery_payload.discovery_data[ATTR_DISCOVERY_HASH]
component, discovery_id = discovery_hash
platform_setup_lock.setdefault(component, asyncio.Lock())
async with platform_setup_lock[component]:
if component not in mqtt_data.platforms_loaded:
await async_forward_entry_setup_and_setup_discovery(
hass, config_entry, {component}
)
# Add component
message = f"Found new component: {component} {discovery_id}"
async_log_discovery_origin_info(message, discovery_payload)
mqtt_data.discovery_already_discovered.add(discovery_hash)
async_dispatcher_send(
hass, MQTT_DISCOVERY_NEW.format(component, "mqtt"), discovery_payload
)
mqtt_data.reload_dispatchers.append(
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW_COMPONENT, _async_jit_component_setup
)
)
@callback
def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901
@ -304,7 +331,10 @@ async def async_start( # noqa: C901
"pending": deque([]),
}
if already_discovered:
if component not in mqtt_data.platforms_loaded and payload:
# Load component first
async_dispatcher_send(hass, MQTT_DISCOVERY_NEW_COMPONENT, payload)
elif already_discovered:
# Dispatch update
message = f"Component has already been discovered: {component} {discovery_id}, sending update"
async_log_discovery_origin_info(message, payload)

View file

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, TypedDict
import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME, Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import ServiceValidationError, TemplateError
from homeassistant.helpers import template
@ -409,6 +409,7 @@ class MqttData:
discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list)
integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
last_discovery: float = 0.0
platforms_loaded: set[Platform | str] = field(default_factory=set)
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
reload_schema: dict[str, vol.Schema] = field(default_factory=dict)

View file

@ -10,10 +10,12 @@ from typing import Any
import voluptuous as vol
from homeassistant.config_entries import ConfigEntryState
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.async_ import create_eager_task
from .const import (
ATTR_PAYLOAD,
@ -39,6 +41,49 @@ TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]:
"""Return the platforms to be set up."""
return {key for platform in config for key in platform}
async def async_forward_entry_setup_and_setup_discovery(
hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str]
) -> None:
"""Forward the config entry setup to the platforms and set up discovery."""
mqtt_data = get_mqtt_data(hass)
platforms_loaded = mqtt_data.platforms_loaded
new_platforms: set[Platform | str] = platforms - platforms_loaded
platforms_loaded.update(new_platforms)
tasks: list[asyncio.Task] = []
if "device_automation" in new_platforms:
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import device_automation
new_platforms.remove("device_automation")
tasks.append(
create_eager_task(device_automation.async_setup_entry(hass, config_entry))
)
if "tag" in new_platforms:
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import tag
new_platforms.remove("tag")
tasks.append(create_eager_task(tag.async_setup_entry(hass, config_entry)))
if new_platforms:
tasks.append(
create_eager_task(
hass.config_entries.async_forward_entry_setups(
config_entry, new_platforms
)
)
)
if not tasks:
return
await asyncio.gather(*tasks)
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
"""Return true when the MQTT config entry is enabled."""
if not bool(hass.config_entries.async_entries(DOMAIN)):

View file

@ -2,6 +2,7 @@
import asyncio
from collections.abc import Generator
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
import json
@ -3546,7 +3547,6 @@ async def test_subscribe_connection_status(
assert mqtt_connected_calls_async[1] is False
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT])
async def test_unload_config_entry(
hass: HomeAssistant,
mqtt_mock: MqttMockHAClient,
@ -3563,6 +3563,7 @@ async def test_unload_config_entry(
# Publish just before unloading to test await cleanup
mqtt_client_mock.reset_mock()
mqtt.publish(hass, "just_in_time", "published", qos=0, retain=False)
await hass.async_block_till_done()
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
new_mqtt_config_entry = mqtt_config_entry
@ -4046,3 +4047,127 @@ async def test_reload_with_empty_config(
await hass.async_block_till_done()
assert hass.states.get("sensor.test") is None
@pytest.mark.parametrize(
"hass_config",
[
{
"mqtt": [
{
"sensor": {
"name": "test",
"state_topic": "test-topic",
}
},
]
}
],
)
async def test_reload_with_new_platform_config(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test reloading yaml with new platform config."""
await mqtt_mock_entry()
assert hass.states.get("sensor.test") is not None
assert hass.states.get("binary_sensor.test") is None
new_config = {
"mqtt": [
{
"sensor": {
"name": "test",
"state_topic": "test-topic1",
},
"binary_sensor": {
"name": "test",
"state_topic": "test-topic2",
},
},
]
}
# Reload with an new platform config and assert again
with patch("homeassistant.config.load_yaml_config_file", return_value=new_config):
await hass.services.async_call(
"mqtt",
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("sensor.test") is not None
assert hass.states.get("binary_sensor.test") is not None
async def test_multi_platform_discovery(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test setting up multiple platforms simultaneous."""
await mqtt_mock_entry()
entity_configs = {
"alarm_control_panel": {
"name": "test",
"state_topic": "alarm/state",
"command_topic": "alarm/command",
},
"button": {"name": "test", "command_topic": "test-topic"},
"camera": {"name": "test", "topic": "test_topic"},
"cover": {"name": "test", "state_topic": "test-topic"},
"device_tracker": {
"name": "test",
"state_topic": "test-topic",
},
"fan": {
"name": "test",
"state_topic": "state-topic",
"command_topic": "command-topic",
},
"sensor": {"name": "test", "state_topic": "test-topic"},
"switch": {"name": "test", "command_topic": "test-topic"},
"select": {
"name": "test",
"command_topic": "test-topic",
"options": ["milk", "beer"],
},
}
non_entity_configs = {
"tag": {
"device": {"identifiers": ["tag_0AFFD2"]},
"topic": "foobar/tag_scanned",
},
"device_automation": {
"automation_type": "trigger",
"device": {"identifiers": ["device_automation_0AFFD2"]},
"payload": "short_press",
"topic": "foobar/triggers/button1",
"type": "button_short_press",
"subtype": "button_1",
},
}
for platform, config in entity_configs.items():
for set_number in range(0, 2):
set_config = deepcopy(config)
set_config["name"] = f"test_{set_number}"
topic = f"homeassistant/{platform}/bla_{set_number}/config"
async_fire_mqtt_message(hass, topic, json.dumps(set_config))
for platform, config in non_entity_configs.items():
topic = f"homeassistant/{platform}/bla/config"
async_fire_mqtt_message(hass, topic, json.dumps(config))
await hass.async_block_till_done()
for set_number in range(0, 2):
for platform in entity_configs:
entity_id = f"{platform}.test_{set_number}"
state = hass.states.get(entity_id)
assert state is not None
for platform in non_entity_configs:
assert (
device_registry.async_get_device(
identifiers={("mqtt", f"{platform}_0AFFD2")}
)
is not None
)