Allow Just-in-Time platform setup for mqtt (#112720)

* Allow Just-in-Time platform setup for mqtt

* Only forward the setup of new platforms

* Fix new  platforms being setup at reload + test

* Revert not related changes

* Remove unused partial

* Address comments, only import plaforms if needed

* Apply suggestions from code review

* Add multipl platform discovery test

* Improve test

* Use a lock per platform
This commit is contained in:
Jan Bouwhuis 2024-03-09 21:55:00 +01:00 committed by GitHub
parent d0d1af8991
commit 3b0ea52167
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 226 additions and 38 deletions

View file

@ -41,7 +41,6 @@ from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.util.async_ import create_eager_task
# Loading the config flow file will register the flow # Loading the config flow file will register the flow
from . import debug_info, discovery from . import debug_info, discovery
@ -100,9 +99,11 @@ from .models import ( # noqa: F401
) )
from .util import ( # noqa: F401 from .util import ( # noqa: F401
async_create_certificate_temp_files, async_create_certificate_temp_files,
async_forward_entry_setup_and_setup_discovery,
async_wait_for_mqtt_client, async_wait_for_mqtt_client,
get_mqtt_data, get_mqtt_data,
mqtt_config_entry_enabled, mqtt_config_entry_enabled,
platforms_from_config,
valid_publish_topic, valid_publish_topic,
valid_qos_schema, valid_qos_schema,
valid_subscribe_topic, valid_subscribe_topic,
@ -411,13 +412,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
translation_placeholders=ex.translation_placeholders, translation_placeholders=ex.translation_placeholders,
) from ex ) from ex
new_config: list[ConfigType] = config_yaml.get(DOMAIN, [])
platforms_used = platforms_from_config(new_config)
new_platforms = platforms_used - mqtt_data.platforms_loaded
await async_forward_entry_setup_and_setup_discovery(
hass, entry, new_platforms
)
# Check the schema before continuing reload # Check the schema before continuing reload
await async_check_config_schema(hass, config_yaml) await async_check_config_schema(hass, config_yaml)
# Remove repair issues # Remove repair issues
_async_remove_mqtt_issues(hass, mqtt_data) _async_remove_mqtt_issues(hass, mqtt_data)
mqtt_data.config = config_yaml.get(DOMAIN, {}) mqtt_data.config = new_config
# Reload the modern yaml platforms # Reload the modern yaml platforms
mqtt_platforms = async_get_platforms(hass, DOMAIN) mqtt_platforms = async_get_platforms(hass, DOMAIN)
@ -439,36 +446,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config) async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config)
async def async_forward_entry_setup_and_setup_discovery( platforms_used = platforms_from_config(mqtt_data.config)
config_entry: ConfigEntry, await async_forward_entry_setup_and_setup_discovery(hass, entry, platforms_used)
conf: ConfigType, # Setup reload service after all platforms have loaded
) -> None: await async_setup_reload_service()
"""Forward the config entry setup to the platforms and set up discovery.""" # Setup discovery
# Local import to avoid circular dependencies if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
# pylint: disable-next=import-outside-toplevel await discovery.async_start(
from . import device_automation, tag hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
# Forward the entry setup to the MQTT platforms
await asyncio.gather(
*(
create_eager_task(
device_automation.async_setup_entry(hass, config_entry)
),
create_eager_task(tag.async_setup_entry(hass, config_entry)),
create_eager_task(
hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
),
)
) )
# Setup discovery
if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
await discovery.async_start(
hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
)
# Setup reload service after all platforms have loaded
await async_setup_reload_service()
await async_forward_entry_setup_and_setup_discovery(entry, conf)
return True return True
@ -605,9 +591,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await asyncio.gather( await asyncio.gather(
*( *(
hass.config_entries.async_forward_entry_unload(entry, component) hass.config_entries.async_forward_entry_unload(entry, component)
for component in PLATFORMS for component in mqtt_data.platforms_loaded
) )
) )
mqtt_data.platforms_loaded = set()
await asyncio.sleep(0) await asyncio.sleep(0)
# Unsubscribe reload dispatchers # Unsubscribe reload dispatchers
while reload_dispatchers := mqtt_data.reload_dispatchers: while reload_dispatchers := mqtt_data.reload_dispatchers:

View file

@ -40,7 +40,7 @@ from .const import (
DOMAIN, DOMAIN,
) )
from .models import MqttOriginInfo, ReceiveMessage from .models import MqttOriginInfo, ReceiveMessage
from .util import get_mqtt_data from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -81,6 +81,7 @@ SUPPORTED_COMPONENTS = {
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component"
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
TOPIC_BASE = "~" TOPIC_BASE = "~"
@ -141,7 +142,33 @@ async def async_start( # noqa: C901
) -> None: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {} platform_setup_lock: dict[str, asyncio.Lock] = {}
async def _async_jit_component_setup(
discovery_payload: MQTTDiscoveryPayload,
) -> None:
"""Perform just in time components set up."""
discovery_hash = discovery_payload.discovery_data[ATTR_DISCOVERY_HASH]
component, discovery_id = discovery_hash
platform_setup_lock.setdefault(component, asyncio.Lock())
async with platform_setup_lock[component]:
if component not in mqtt_data.platforms_loaded:
await async_forward_entry_setup_and_setup_discovery(
hass, config_entry, {component}
)
# Add component
message = f"Found new component: {component} {discovery_id}"
async_log_discovery_origin_info(message, discovery_payload)
mqtt_data.discovery_already_discovered.add(discovery_hash)
async_dispatcher_send(
hass, MQTT_DISCOVERY_NEW.format(component, "mqtt"), discovery_payload
)
mqtt_data.reload_dispatchers.append(
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW_COMPONENT, _async_jit_component_setup
)
)
@callback @callback
def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901 def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901
@ -304,7 +331,10 @@ async def async_start( # noqa: C901
"pending": deque([]), "pending": deque([]),
} }
if already_discovered: if component not in mqtt_data.platforms_loaded and payload:
# Load component first
async_dispatcher_send(hass, MQTT_DISCOVERY_NEW_COMPONENT, payload)
elif already_discovered:
# Dispatch update # Dispatch update
message = f"Component has already been discovered: {component} {discovery_id}, sending update" message = f"Component has already been discovered: {component} {discovery_id}, sending update"
async_log_discovery_origin_info(message, payload) async_log_discovery_origin_info(message, payload)

View file

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, TypedDict
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME, Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import ServiceValidationError, TemplateError from homeassistant.exceptions import ServiceValidationError, TemplateError
from homeassistant.helpers import template from homeassistant.helpers import template
@ -409,6 +409,7 @@ class MqttData:
discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list) discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list)
integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict) integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
last_discovery: float = 0.0 last_discovery: float = 0.0
platforms_loaded: set[Platform | str] = field(default_factory=set)
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict) reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
reload_schema: dict[str, vol.Schema] = field(default_factory=dict) reload_schema: dict[str, vol.Schema] = field(default_factory=dict)

View file

@ -10,10 +10,12 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util.async_ import create_eager_task
from .const import ( from .const import (
ATTR_PAYLOAD, ATTR_PAYLOAD,
@ -39,6 +41,49 @@ TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]:
"""Return the platforms to be set up."""
return {key for platform in config for key in platform}
async def async_forward_entry_setup_and_setup_discovery(
hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str]
) -> None:
"""Forward the config entry setup to the platforms and set up discovery."""
mqtt_data = get_mqtt_data(hass)
platforms_loaded = mqtt_data.platforms_loaded
new_platforms: set[Platform | str] = platforms - platforms_loaded
platforms_loaded.update(new_platforms)
tasks: list[asyncio.Task] = []
if "device_automation" in new_platforms:
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import device_automation
new_platforms.remove("device_automation")
tasks.append(
create_eager_task(device_automation.async_setup_entry(hass, config_entry))
)
if "tag" in new_platforms:
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import tag
new_platforms.remove("tag")
tasks.append(create_eager_task(tag.async_setup_entry(hass, config_entry)))
if new_platforms:
tasks.append(
create_eager_task(
hass.config_entries.async_forward_entry_setups(
config_entry, new_platforms
)
)
)
if not tasks:
return
await asyncio.gather(*tasks)
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
"""Return true when the MQTT config entry is enabled.""" """Return true when the MQTT config entry is enabled."""
if not bool(hass.config_entries.async_entries(DOMAIN)): if not bool(hass.config_entries.async_entries(DOMAIN)):

View file

@ -2,6 +2,7 @@
import asyncio import asyncio
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import json import json
@ -3546,7 +3547,6 @@ async def test_subscribe_connection_status(
assert mqtt_connected_calls_async[1] is False assert mqtt_connected_calls_async[1] is False
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT])
async def test_unload_config_entry( async def test_unload_config_entry(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock: MqttMockHAClient, mqtt_mock: MqttMockHAClient,
@ -3563,6 +3563,7 @@ async def test_unload_config_entry(
# Publish just before unloading to test await cleanup # Publish just before unloading to test await cleanup
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
mqtt.publish(hass, "just_in_time", "published", qos=0, retain=False) mqtt.publish(hass, "just_in_time", "published", qos=0, retain=False)
await hass.async_block_till_done()
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id) assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
new_mqtt_config_entry = mqtt_config_entry new_mqtt_config_entry = mqtt_config_entry
@ -4046,3 +4047,127 @@ async def test_reload_with_empty_config(
await hass.async_block_till_done() await hass.async_block_till_done()
assert hass.states.get("sensor.test") is None assert hass.states.get("sensor.test") is None
@pytest.mark.parametrize(
"hass_config",
[
{
"mqtt": [
{
"sensor": {
"name": "test",
"state_topic": "test-topic",
}
},
]
}
],
)
async def test_reload_with_new_platform_config(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test reloading yaml with new platform config."""
await mqtt_mock_entry()
assert hass.states.get("sensor.test") is not None
assert hass.states.get("binary_sensor.test") is None
new_config = {
"mqtt": [
{
"sensor": {
"name": "test",
"state_topic": "test-topic1",
},
"binary_sensor": {
"name": "test",
"state_topic": "test-topic2",
},
},
]
}
# Reload with an new platform config and assert again
with patch("homeassistant.config.load_yaml_config_file", return_value=new_config):
await hass.services.async_call(
"mqtt",
SERVICE_RELOAD,
{},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("sensor.test") is not None
assert hass.states.get("binary_sensor.test") is not None
async def test_multi_platform_discovery(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test setting up multiple platforms simultaneous."""
await mqtt_mock_entry()
entity_configs = {
"alarm_control_panel": {
"name": "test",
"state_topic": "alarm/state",
"command_topic": "alarm/command",
},
"button": {"name": "test", "command_topic": "test-topic"},
"camera": {"name": "test", "topic": "test_topic"},
"cover": {"name": "test", "state_topic": "test-topic"},
"device_tracker": {
"name": "test",
"state_topic": "test-topic",
},
"fan": {
"name": "test",
"state_topic": "state-topic",
"command_topic": "command-topic",
},
"sensor": {"name": "test", "state_topic": "test-topic"},
"switch": {"name": "test", "command_topic": "test-topic"},
"select": {
"name": "test",
"command_topic": "test-topic",
"options": ["milk", "beer"],
},
}
non_entity_configs = {
"tag": {
"device": {"identifiers": ["tag_0AFFD2"]},
"topic": "foobar/tag_scanned",
},
"device_automation": {
"automation_type": "trigger",
"device": {"identifiers": ["device_automation_0AFFD2"]},
"payload": "short_press",
"topic": "foobar/triggers/button1",
"type": "button_short_press",
"subtype": "button_1",
},
}
for platform, config in entity_configs.items():
for set_number in range(0, 2):
set_config = deepcopy(config)
set_config["name"] = f"test_{set_number}"
topic = f"homeassistant/{platform}/bla_{set_number}/config"
async_fire_mqtt_message(hass, topic, json.dumps(set_config))
for platform, config in non_entity_configs.items():
topic = f"homeassistant/{platform}/bla/config"
async_fire_mqtt_message(hass, topic, json.dumps(config))
await hass.async_block_till_done()
for set_number in range(0, 2):
for platform in entity_configs:
entity_id = f"{platform}.test_{set_number}"
state = hass.states.get(entity_id)
assert state is not None
for platform in non_entity_configs:
assert (
device_registry.async_get_device(
identifiers={("mqtt", f"{platform}_0AFFD2")}
)
is not None
)