diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 14f92e29cdf..397e530cb4b 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -41,7 +41,6 @@ from homeassistant.helpers.reload import async_integration_yaml_config from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_integration -from homeassistant.util.async_ import create_eager_task # Loading the config flow file will register the flow from . import debug_info, discovery @@ -100,9 +99,11 @@ from .models import ( # noqa: F401 ) from .util import ( # noqa: F401 async_create_certificate_temp_files, + async_forward_entry_setup_and_setup_discovery, async_wait_for_mqtt_client, get_mqtt_data, mqtt_config_entry_enabled, + platforms_from_config, valid_publish_topic, valid_qos_schema, valid_subscribe_topic, @@ -411,13 +412,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: translation_placeholders=ex.translation_placeholders, ) from ex + new_config: list[ConfigType] = config_yaml.get(DOMAIN, []) + platforms_used = platforms_from_config(new_config) + new_platforms = platforms_used - mqtt_data.platforms_loaded + await async_forward_entry_setup_and_setup_discovery( + hass, entry, new_platforms + ) # Check the schema before continuing reload await async_check_config_schema(hass, config_yaml) # Remove repair issues _async_remove_mqtt_issues(hass, mqtt_data) - mqtt_data.config = config_yaml.get(DOMAIN, {}) + mqtt_data.config = new_config # Reload the modern yaml platforms mqtt_platforms = async_get_platforms(hass, DOMAIN) @@ -439,36 +446,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config) - async def async_forward_entry_setup_and_setup_discovery( - config_entry: ConfigEntry, - conf: ConfigType, - ) -> None: - """Forward the config entry setup to the platforms and set up discovery.""" - # Local import to avoid circular dependencies - # pylint: disable-next=import-outside-toplevel - from . import device_automation, tag - - # Forward the entry setup to the MQTT platforms - await asyncio.gather( - *( - create_eager_task( - device_automation.async_setup_entry(hass, config_entry) - ), - create_eager_task(tag.async_setup_entry(hass, config_entry)), - create_eager_task( - hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - ), - ) + platforms_used = platforms_from_config(mqtt_data.config) + await async_forward_entry_setup_and_setup_discovery(hass, entry, platforms_used) + # Setup reload service after all platforms have loaded + await async_setup_reload_service() + # Setup discovery + if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY): + await discovery.async_start( + hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry ) - # Setup discovery - if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY): - await discovery.async_start( - hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry - ) - # Setup reload service after all platforms have loaded - await async_setup_reload_service() - - await async_forward_entry_setup_and_setup_discovery(entry, conf) return True @@ -605,9 +591,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await asyncio.gather( *( hass.config_entries.async_forward_entry_unload(entry, component) - for component in PLATFORMS + for component in mqtt_data.platforms_loaded ) ) + mqtt_data.platforms_loaded = set() await asyncio.sleep(0) # Unsubscribe reload dispatchers while reload_dispatchers := mqtt_data.reload_dispatchers: diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 5fa1b6297d7..180f3524dee 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -40,7 +40,7 @@ from .const import ( DOMAIN, ) from .models import MqttOriginInfo, ReceiveMessage -from .util import get_mqtt_data +from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -81,6 +81,7 @@ SUPPORTED_COMPONENTS = { MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" +MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component" MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" TOPIC_BASE = "~" @@ -141,7 +142,33 @@ async def async_start( # noqa: C901 ) -> None: """Start MQTT Discovery.""" mqtt_data = get_mqtt_data(hass) - mqtt_integrations = {} + platform_setup_lock: dict[str, asyncio.Lock] = {} + + async def _async_jit_component_setup( + discovery_payload: MQTTDiscoveryPayload, + ) -> None: + """Perform just in time components set up.""" + discovery_hash = discovery_payload.discovery_data[ATTR_DISCOVERY_HASH] + component, discovery_id = discovery_hash + platform_setup_lock.setdefault(component, asyncio.Lock()) + async with platform_setup_lock[component]: + if component not in mqtt_data.platforms_loaded: + await async_forward_entry_setup_and_setup_discovery( + hass, config_entry, {component} + ) + # Add component + message = f"Found new component: {component} {discovery_id}" + async_log_discovery_origin_info(message, discovery_payload) + mqtt_data.discovery_already_discovered.add(discovery_hash) + async_dispatcher_send( + hass, MQTT_DISCOVERY_NEW.format(component, "mqtt"), discovery_payload + ) + + mqtt_data.reload_dispatchers.append( + async_dispatcher_connect( + hass, MQTT_DISCOVERY_NEW_COMPONENT, _async_jit_component_setup + ) + ) @callback def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901 @@ -304,7 +331,10 @@ async def async_start( # noqa: C901 "pending": deque([]), } - if already_discovered: + if component not in mqtt_data.platforms_loaded and payload: + # Load component first + async_dispatcher_send(hass, MQTT_DISCOVERY_NEW_COMPONENT, payload) + elif already_discovered: # Dispatch update message = f"Component has already been discovered: {component} {discovery_id}, sending update" async_log_discovery_origin_info(message, payload) diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index bfbf9e011eb..9c961a3b543 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, TypedDict import voluptuous as vol -from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME +from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME, Platform from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import ServiceValidationError, TemplateError from homeassistant.helpers import template @@ -409,6 +409,7 @@ class MqttData: discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list) integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict) last_discovery: float = 0.0 + platforms_loaded: set[Platform | str] = field(default_factory=set) reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict) reload_schema: dict[str, vol.Schema] = field(default_factory=dict) diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index fb47bbfc667..53462f87321 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -10,10 +10,12 @@ from typing import Any import voluptuous as vol -from homeassistant.config_entries import ConfigEntryState +from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.typing import ConfigType +from homeassistant.util.async_ import create_eager_task from .const import ( ATTR_PAYLOAD, @@ -39,6 +41,49 @@ TEMP_DIR_NAME = f"home-assistant-{DOMAIN}" _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) +def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]: + """Return the platforms to be set up.""" + return {key for platform in config for key in platform} + + +async def async_forward_entry_setup_and_setup_discovery( + hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str] +) -> None: + """Forward the config entry setup to the platforms and set up discovery.""" + mqtt_data = get_mqtt_data(hass) + platforms_loaded = mqtt_data.platforms_loaded + new_platforms: set[Platform | str] = platforms - platforms_loaded + platforms_loaded.update(new_platforms) + tasks: list[asyncio.Task] = [] + if "device_automation" in new_platforms: + # Local import to avoid circular dependencies + # pylint: disable-next=import-outside-toplevel + from . import device_automation + + new_platforms.remove("device_automation") + tasks.append( + create_eager_task(device_automation.async_setup_entry(hass, config_entry)) + ) + if "tag" in new_platforms: + # Local import to avoid circular dependencies + # pylint: disable-next=import-outside-toplevel + from . import tag + + new_platforms.remove("tag") + tasks.append(create_eager_task(tag.async_setup_entry(hass, config_entry))) + if new_platforms: + tasks.append( + create_eager_task( + hass.config_entries.async_forward_entry_setups( + config_entry, new_platforms + ) + ) + ) + if not tasks: + return + await asyncio.gather(*tasks) + + def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: """Return true when the MQTT config entry is enabled.""" if not bool(hass.config_entries.async_entries(DOMAIN)): diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index b10031c75f8..60a909395f1 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -2,6 +2,7 @@ import asyncio from collections.abc import Generator +from copy import deepcopy from datetime import datetime, timedelta from functools import partial import json @@ -3546,7 +3547,6 @@ async def test_subscribe_connection_status( assert mqtt_connected_calls_async[1] is False -@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT]) async def test_unload_config_entry( hass: HomeAssistant, mqtt_mock: MqttMockHAClient, @@ -3563,6 +3563,7 @@ async def test_unload_config_entry( # Publish just before unloading to test await cleanup mqtt_client_mock.reset_mock() mqtt.publish(hass, "just_in_time", "published", qos=0, retain=False) + await hass.async_block_till_done() assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id) new_mqtt_config_entry = mqtt_config_entry @@ -4046,3 +4047,127 @@ async def test_reload_with_empty_config( await hass.async_block_till_done() assert hass.states.get("sensor.test") is None + + +@pytest.mark.parametrize( + "hass_config", + [ + { + "mqtt": [ + { + "sensor": { + "name": "test", + "state_topic": "test-topic", + } + }, + ] + } + ], +) +async def test_reload_with_new_platform_config( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, +) -> None: + """Test reloading yaml with new platform config.""" + await mqtt_mock_entry() + assert hass.states.get("sensor.test") is not None + assert hass.states.get("binary_sensor.test") is None + + new_config = { + "mqtt": [ + { + "sensor": { + "name": "test", + "state_topic": "test-topic1", + }, + "binary_sensor": { + "name": "test", + "state_topic": "test-topic2", + }, + }, + ] + } + + # Reload with an new platform config and assert again + with patch("homeassistant.config.load_yaml_config_file", return_value=new_config): + await hass.services.async_call( + "mqtt", + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert hass.states.get("sensor.test") is not None + assert hass.states.get("binary_sensor.test") is not None + + +async def test_multi_platform_discovery( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mqtt_mock_entry: MqttMockHAClientGenerator, +) -> None: + """Test setting up multiple platforms simultaneous.""" + await mqtt_mock_entry() + entity_configs = { + "alarm_control_panel": { + "name": "test", + "state_topic": "alarm/state", + "command_topic": "alarm/command", + }, + "button": {"name": "test", "command_topic": "test-topic"}, + "camera": {"name": "test", "topic": "test_topic"}, + "cover": {"name": "test", "state_topic": "test-topic"}, + "device_tracker": { + "name": "test", + "state_topic": "test-topic", + }, + "fan": { + "name": "test", + "state_topic": "state-topic", + "command_topic": "command-topic", + }, + "sensor": {"name": "test", "state_topic": "test-topic"}, + "switch": {"name": "test", "command_topic": "test-topic"}, + "select": { + "name": "test", + "command_topic": "test-topic", + "options": ["milk", "beer"], + }, + } + non_entity_configs = { + "tag": { + "device": {"identifiers": ["tag_0AFFD2"]}, + "topic": "foobar/tag_scanned", + }, + "device_automation": { + "automation_type": "trigger", + "device": {"identifiers": ["device_automation_0AFFD2"]}, + "payload": "short_press", + "topic": "foobar/triggers/button1", + "type": "button_short_press", + "subtype": "button_1", + }, + } + for platform, config in entity_configs.items(): + for set_number in range(0, 2): + set_config = deepcopy(config) + set_config["name"] = f"test_{set_number}" + topic = f"homeassistant/{platform}/bla_{set_number}/config" + async_fire_mqtt_message(hass, topic, json.dumps(set_config)) + for platform, config in non_entity_configs.items(): + topic = f"homeassistant/{platform}/bla/config" + async_fire_mqtt_message(hass, topic, json.dumps(config)) + await hass.async_block_till_done() + for set_number in range(0, 2): + for platform in entity_configs: + entity_id = f"{platform}.test_{set_number}" + state = hass.states.get(entity_id) + assert state is not None + for platform in non_entity_configs: + assert ( + device_registry.async_get_device( + identifiers={("mqtt", f"{platform}_0AFFD2")} + ) + is not None + )