From 7e8d52e5a321e30322fb8c006e936d5dc9bfbb83 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 30 Mar 2022 05:26:11 +0200 Subject: [PATCH] Fix race when handling rapid succession of MQTT discovery messages (#68785) Co-authored-by: jbouwh --- .../components/mqtt/binary_sensor.py | 5 +- homeassistant/components/mqtt/mixins.py | 32 +++++++++--- homeassistant/components/mqtt/scene.py | 1 + homeassistant/components/mqtt/sensor.py | 5 +- tests/components/mqtt/test_config_flow.py | 20 +++++++ tests/components/mqtt/test_discovery.py | 52 +++++++++++++++++++ tests/conftest.py | 2 + 7 files changed, 103 insertions(+), 14 deletions(-) diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index 40d5c876c11..5d0da99d786 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -118,16 +118,15 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): MqttEntity.__init__(self, hass, config, config_entry, discovery_data) - async def async_added_to_hass(self) -> None: + async def mqtt_async_added_to_hass(self) -> None: """Restore state for entities with expire_after set.""" - await super().async_added_to_hass() if ( (expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None and expire_after > 0 and (last_state := await self.async_get_last_state()) is not None and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE] # We might have set up a trigger already after subscribing from - # super().async_added_to_hass(), then we should not restore state + # MqttEntity.async_added_to_hass(), then we should not restore state and not self._expiration_trigger ): expiration_at = last_state.last_changed + timedelta(seconds=expire_after) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 19c44ba8cdd..659d0debe31 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections.abc import Callable import json import logging -from typing import Any, Protocol +from typing import Any, Protocol, final import voluptuous as vol @@ -572,9 +572,7 @@ class MqttDiscoveryUpdate(Entity): else: # Non-empty, unchanged payload: Ignore to avoid changing states _LOGGER.info("Ignoring unchanged update for: %s", self.entity_id) - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) + self.async_send_discovery_done() if discovery_hash: debug_info.add_entity_discovery_data( @@ -587,9 +585,18 @@ class MqttDiscoveryUpdate(Entity): MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_callback, ) - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) + + @callback + def async_send_discovery_done(self) -> None: + """Acknowledge a discovery message has been handled.""" + discovery_hash = ( + self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None + ) + if not discovery_hash: + return + async_dispatcher_send( + self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None + ) async def async_removed_from_registry(self) -> None: """Clear retained discovery topic in broker.""" @@ -723,11 +730,20 @@ class MqttEntity( self.hass, self, self._config, self._entity_id_format ) + @final async def async_added_to_hass(self): - """Subscribe mqtt events.""" + """Subscribe to MQTT events.""" await super().async_added_to_hass() self._prepare_subscribe_topics() await self._subscribe_topics() + await self.mqtt_async_added_to_hass() + self.async_send_discovery_done() + + async def mqtt_async_added_to_hass(self): + """Call before the discovery message is acknowledged. + + To be extended by subclasses. + """ async def discovery_update(self, discovery_payload): """Handle updated discovery message.""" diff --git a/homeassistant/components/mqtt/scene.py b/homeassistant/components/mqtt/scene.py index c44ea1dca53..5584bfb27db 100644 --- a/homeassistant/components/mqtt/scene.py +++ b/homeassistant/components/mqtt/scene.py @@ -110,6 +110,7 @@ class MqttScene( async def async_added_to_hass(self): """Subscribe to MQTT events.""" await super().async_added_to_hass() + self.async_send_discovery_done() async def discovery_update(self, discovery_payload): """Handle updated discovery message.""" diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index c24535ebd1f..a13f58f95ea 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -163,9 +163,8 @@ class MqttSensor(MqttEntity, RestoreSensor): MqttEntity.__init__(self, hass, config, config_entry, discovery_data) - async def async_added_to_hass(self) -> None: + async def mqtt_async_added_to_hass(self) -> None: """Restore state for entities with expire_after set.""" - await super().async_added_to_hass() if ( (expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None and expire_after > 0 @@ -174,7 +173,7 @@ class MqttSensor(MqttEntity, RestoreSensor): and (last_sensor_data := await self.async_get_last_sensor_data()) is not None # We might have set up a trigger already after subscribing from - # super().async_added_to_hass(), then we should not restore state + # MqttEntity.async_added_to_hass(), then we should not restore state and not self._expiration_trigger ): expiration_at = last_state.last_changed + timedelta(seconds=expire_after) diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index 88c6137bf94..f0e02ad8a3a 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -34,12 +34,32 @@ def mock_try_connection(): def mock_try_connection_success(): """Mock the try connection method with success.""" + _mid = 1 + + def get_mid(): + nonlocal _mid + _mid += 1 + return _mid + def loop_start(): """Simulate connect on loop start.""" mock_client().on_connect(mock_client, None, None, 0) + def _subscribe(topic, qos=0): + mid = get_mid() + mock_client().on_subscribe(mock_client, 0, mid) + return (0, mid) + + def _unsubscribe(topic): + mid = get_mid() + mock_client().on_unsubscribe(mock_client, 0, mid) + return (0, mid) + with patch("paho.mqtt.client.Client") as mock_client: mock_client().loop_start = loop_start + mock_client().subscribe = _subscribe + mock_client().unsubscribe = _unsubscribe + yield mock_client() diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index d0c00f6b2ce..a1ef6ea477a 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -550,6 +550,58 @@ async def test_rapid_rediscover_unique(hass, mqtt_mock, caplog): assert events[3].data["old_state"] is None +async def test_rapid_reconfigure(hass, mqtt_mock, caplog): + """Test immediate reconfigure of added component.""" + + events = [] + + @ha.callback + def callback(event): + """Verify event got called.""" + events.append(event) + + hass.bus.async_listen(EVENT_STATE_CHANGED, callback) + + # Discovery immediately followed by reconfig + async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "") + async_fire_mqtt_message( + hass, + "homeassistant/binary_sensor/bla/config", + '{ "name": "Beer", "state_topic": "test-topic1" }', + ) + async_fire_mqtt_message( + hass, + "homeassistant/binary_sensor/bla/config", + '{ "name": "Milk", "state_topic": "test-topic2" }', + ) + async_fire_mqtt_message( + hass, + "homeassistant/binary_sensor/bla/config", + '{ "name": "Wine", "state_topic": "test-topic3" }', + ) + await hass.async_block_till_done() + + assert len(hass.states.async_entity_ids("binary_sensor")) == 1 + state = hass.states.get("binary_sensor.beer") + assert state is not None + + assert len(events) == 3 + # Add the entity + assert events[0].data["entity_id"] == "binary_sensor.beer" + assert events[0].data["old_state"] is None + assert events[0].data["new_state"].attributes["friendly_name"] == "Beer" + # Update the entity + assert events[1].data["entity_id"] == "binary_sensor.beer" + assert events[1].data["new_state"] is not None + assert events[1].data["old_state"] is not None + assert events[1].data["new_state"].attributes["friendly_name"] == "Milk" + # Update the entity + assert events[2].data["entity_id"] == "binary_sensor.beer" + assert events[2].data["new_state"] is not None + assert events[2].data["old_state"] is not None + assert events[2].data["new_state"].attributes["friendly_name"] == "Wine" + + async def test_duplicate_removal(hass, mqtt_mock, caplog): """Test for a non duplicate component.""" async_fire_mqtt_message( diff --git a/tests/conftest.py b/tests/conftest.py index 92f19f087a6..ae982443ea4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -618,6 +618,8 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config): ) mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics mqtt_component_mock._mqttc = mqtt_client_mock + # connected set to True to get a more realistics behavior when subscribing + hass.data["mqtt"].connected = True hass.data["mqtt"] = mqtt_component_mock component = hass.data["mqtt"]