Fix race when handling rapid succession of MQTT discovery messages (#68785)

Co-authored-by: jbouwh <jan@jbsoft.nl>
This commit is contained in:
Erik Montnemery 2022-03-30 05:26:11 +02:00 committed by GitHub
parent 3d378449e8
commit 7e8d52e5a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 14 deletions

View file

@ -118,16 +118,15 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def async_added_to_hass(self) -> None: async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set.""" """Restore state for entities with expire_after set."""
await super().async_added_to_hass()
if ( if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None (expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
and expire_after > 0 and expire_after > 0
and (last_state := await self.async_get_last_state()) is not None and (last_state := await self.async_get_last_state()) is not None
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE] and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
# We might have set up a trigger already after subscribing from # We might have set up a trigger already after subscribing from
# super().async_added_to_hass(), then we should not restore state # MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger and not self._expiration_trigger
): ):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after) expiration_at = last_state.last_changed + timedelta(seconds=expire_after)

View file

@ -5,7 +5,7 @@ from abc import abstractmethod
from collections.abc import Callable from collections.abc import Callable
import json import json
import logging import logging
from typing import Any, Protocol from typing import Any, Protocol, final
import voluptuous as vol import voluptuous as vol
@ -572,9 +572,7 @@ class MqttDiscoveryUpdate(Entity):
else: else:
# Non-empty, unchanged payload: Ignore to avoid changing states # Non-empty, unchanged payload: Ignore to avoid changing states
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id) _LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
async_dispatcher_send( self.async_send_discovery_done()
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
if discovery_hash: if discovery_hash:
debug_info.add_entity_discovery_data( debug_info.add_entity_discovery_data(
@ -587,9 +585,18 @@ class MqttDiscoveryUpdate(Entity):
MQTT_DISCOVERY_UPDATED.format(discovery_hash), MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback, discovery_callback,
) )
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None @callback
) def async_send_discovery_done(self) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
if not discovery_hash:
return
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
async def async_removed_from_registry(self) -> None: async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker.""" """Clear retained discovery topic in broker."""
@ -723,11 +730,20 @@ class MqttEntity(
self.hass, self, self._config, self._entity_id_format self.hass, self, self._config, self._entity_id_format
) )
@final
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe mqtt events.""" """Subscribe to MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._prepare_subscribe_topics() self._prepare_subscribe_topics()
await self._subscribe_topics() await self._subscribe_topics()
await self.mqtt_async_added_to_hass()
self.async_send_discovery_done()
async def mqtt_async_added_to_hass(self):
"""Call before the discovery message is acknowledged.
To be extended by subclasses.
"""
async def discovery_update(self, discovery_payload): async def discovery_update(self, discovery_payload):
"""Handle updated discovery message.""" """Handle updated discovery message."""

View file

@ -110,6 +110,7 @@ class MqttScene(
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to MQTT events.""" """Subscribe to MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self.async_send_discovery_done()
async def discovery_update(self, discovery_payload): async def discovery_update(self, discovery_payload):
"""Handle updated discovery message.""" """Handle updated discovery message."""

View file

@ -163,9 +163,8 @@ class MqttSensor(MqttEntity, RestoreSensor):
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def async_added_to_hass(self) -> None: async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set.""" """Restore state for entities with expire_after set."""
await super().async_added_to_hass()
if ( if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None (expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
and expire_after > 0 and expire_after > 0
@ -174,7 +173,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
and (last_sensor_data := await self.async_get_last_sensor_data()) and (last_sensor_data := await self.async_get_last_sensor_data())
is not None is not None
# We might have set up a trigger already after subscribing from # We might have set up a trigger already after subscribing from
# super().async_added_to_hass(), then we should not restore state # MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger and not self._expiration_trigger
): ):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after) expiration_at = last_state.last_changed + timedelta(seconds=expire_after)

View file

@ -34,12 +34,32 @@ def mock_try_connection():
def mock_try_connection_success(): def mock_try_connection_success():
"""Mock the try connection method with success.""" """Mock the try connection method with success."""
_mid = 1
def get_mid():
nonlocal _mid
_mid += 1
return _mid
def loop_start(): def loop_start():
"""Simulate connect on loop start.""" """Simulate connect on loop start."""
mock_client().on_connect(mock_client, None, None, 0) mock_client().on_connect(mock_client, None, None, 0)
def _subscribe(topic, qos=0):
mid = get_mid()
mock_client().on_subscribe(mock_client, 0, mid)
return (0, mid)
def _unsubscribe(topic):
mid = get_mid()
mock_client().on_unsubscribe(mock_client, 0, mid)
return (0, mid)
with patch("paho.mqtt.client.Client") as mock_client: with patch("paho.mqtt.client.Client") as mock_client:
mock_client().loop_start = loop_start mock_client().loop_start = loop_start
mock_client().subscribe = _subscribe
mock_client().unsubscribe = _unsubscribe
yield mock_client() yield mock_client()

View file

@ -550,6 +550,58 @@ async def test_rapid_rediscover_unique(hass, mqtt_mock, caplog):
assert events[3].data["old_state"] is None assert events[3].data["old_state"] is None
async def test_rapid_reconfigure(hass, mqtt_mock, caplog):
"""Test immediate reconfigure of added component."""
events = []
@ha.callback
def callback(event):
"""Verify event got called."""
events.append(event)
hass.bus.async_listen(EVENT_STATE_CHANGED, callback)
# Discovery immediately followed by reconfig
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic1" }',
)
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Milk", "state_topic": "test-topic2" }',
)
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Wine", "state_topic": "test-topic3" }',
)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert len(events) == 3
# Add the entity
assert events[0].data["entity_id"] == "binary_sensor.beer"
assert events[0].data["old_state"] is None
assert events[0].data["new_state"].attributes["friendly_name"] == "Beer"
# Update the entity
assert events[1].data["entity_id"] == "binary_sensor.beer"
assert events[1].data["new_state"] is not None
assert events[1].data["old_state"] is not None
assert events[1].data["new_state"].attributes["friendly_name"] == "Milk"
# Update the entity
assert events[2].data["entity_id"] == "binary_sensor.beer"
assert events[2].data["new_state"] is not None
assert events[2].data["old_state"] is not None
assert events[2].data["new_state"].attributes["friendly_name"] == "Wine"
async def test_duplicate_removal(hass, mqtt_mock, caplog): async def test_duplicate_removal(hass, mqtt_mock, caplog):
"""Test for a non duplicate component.""" """Test for a non duplicate component."""
async_fire_mqtt_message( async_fire_mqtt_message(

View file

@ -618,6 +618,8 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
) )
mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics
mqtt_component_mock._mqttc = mqtt_client_mock mqtt_component_mock._mqttc = mqtt_client_mock
# connected set to True to get a more realistics behavior when subscribing
hass.data["mqtt"].connected = True
hass.data["mqtt"] = mqtt_component_mock hass.data["mqtt"] = mqtt_component_mock
component = hass.data["mqtt"] component = hass.data["mqtt"]