Fix mqtt platform setup race (#112888)

This commit is contained in:
Jan Bouwhuis 2024-03-10 19:36:17 +01:00 committed by GitHub
parent eb81bf1d49
commit c608d1cb85
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 12 deletions

View file

@ -53,14 +53,12 @@ async def async_forward_entry_setup_and_setup_discovery(
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
platforms_loaded = mqtt_data.platforms_loaded platforms_loaded = mqtt_data.platforms_loaded
new_platforms: set[Platform | str] = platforms - platforms_loaded new_platforms: set[Platform | str] = platforms - platforms_loaded
platforms_loaded.update(new_platforms)
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
if "device_automation" in new_platforms: if "device_automation" in new_platforms:
# Local import to avoid circular dependencies # Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
from . import device_automation from . import device_automation
new_platforms.remove("device_automation")
tasks.append( tasks.append(
create_eager_task(device_automation.async_setup_entry(hass, config_entry)) create_eager_task(device_automation.async_setup_entry(hass, config_entry))
) )
@ -69,19 +67,19 @@ async def async_forward_entry_setup_and_setup_discovery(
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
from . import tag from . import tag
new_platforms.remove("tag")
tasks.append(create_eager_task(tag.async_setup_entry(hass, config_entry))) tasks.append(create_eager_task(tag.async_setup_entry(hass, config_entry)))
if new_platforms: if new_entity_platforms := (new_platforms - {"tag", "device_automation"}):
tasks.append( tasks.append(
create_eager_task( create_eager_task(
hass.config_entries.async_forward_entry_setups( hass.config_entries.async_forward_entry_setups(
config_entry, new_platforms config_entry, new_entity_platforms
) )
) )
) )
if not tasks: if not tasks:
return return
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
platforms_loaded.update(new_platforms)
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:

View file

@ -936,6 +936,9 @@ async def help_test_encoding_subscribable_topics(
hass, f"homeassistant/{domain}/item3/config", json.dumps(config3) hass, f"homeassistant/{domain}/item3/config", json.dumps(config3)
) )
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
expected_result = attribute_value or value expected_result = attribute_value or value

View file

@ -3598,14 +3598,20 @@ async def test_disabling_and_enabling_entry(
config_alarm_control_panel = '{"name": "test_new", "state_topic": "home/alarm", "command_topic": "home/alarm/set"}' config_alarm_control_panel = '{"name": "test_new", "state_topic": "home/alarm", "command_topic": "home/alarm/set"}'
config_light = '{"name": "test_new", "command_topic": "test-topic_new"}' config_light = '{"name": "test_new", "command_topic": "test-topic_new"}'
# Discovery of mqtt tag with patch(
async_fire_mqtt_message(hass, "homeassistant/tag/abc/config", config_tag) "homeassistant.components.mqtt.mixins.mqtt_config_entry_enabled",
return_value=False,
):
# Discovery of mqtt tag
async_fire_mqtt_message(hass, "homeassistant/tag/abc/config", config_tag)
# Late discovery of mqtt entities # Late discovery of mqtt entities
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "homeassistant/alarm_control_panel/abc/config", config_alarm_control_panel hass,
) "homeassistant/alarm_control_panel/abc/config",
async_fire_mqtt_message(hass, "homeassistant/light/abc/config", config_light) config_alarm_control_panel,
)
async_fire_mqtt_message(hass, "homeassistant/light/abc/config", config_light)
# Disable MQTT config entry # Disable MQTT config entry
await hass.config_entries.async_set_disabled_by( await hass.config_entries.async_set_disabled_by(