Fix race when handling rapid succession of MQTT discovery messages (#68785)
Co-authored-by: jbouwh <jan@jbsoft.nl>
This commit is contained in:
parent
3d378449e8
commit
7e8d52e5a3
7 changed files with 103 additions and 14 deletions
|
@ -118,16 +118,15 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
|||
|
||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
async def mqtt_async_added_to_hass(self) -> None:
|
||||
"""Restore state for entities with expire_after set."""
|
||||
await super().async_added_to_hass()
|
||||
if (
|
||||
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
||||
and expire_after > 0
|
||||
and (last_state := await self.async_get_last_state()) is not None
|
||||
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
|
||||
# We might have set up a trigger already after subscribing from
|
||||
# super().async_added_to_hass(), then we should not restore state
|
||||
# MqttEntity.async_added_to_hass(), then we should not restore state
|
||||
and not self._expiration_trigger
|
||||
):
|
||||
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
||||
|
|
|
@ -5,7 +5,7 @@ from abc import abstractmethod
|
|||
from collections.abc import Callable
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -572,9 +572,7 @@ class MqttDiscoveryUpdate(Entity):
|
|||
else:
|
||||
# Non-empty, unchanged payload: Ignore to avoid changing states
|
||||
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
|
||||
async_dispatcher_send(
|
||||
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
||||
)
|
||||
self.async_send_discovery_done()
|
||||
|
||||
if discovery_hash:
|
||||
debug_info.add_entity_discovery_data(
|
||||
|
@ -587,9 +585,18 @@ class MqttDiscoveryUpdate(Entity):
|
|||
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
|
||||
discovery_callback,
|
||||
)
|
||||
async_dispatcher_send(
|
||||
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_send_discovery_done(self) -> None:
|
||||
"""Acknowledge a discovery message has been handled."""
|
||||
discovery_hash = (
|
||||
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
|
||||
)
|
||||
if not discovery_hash:
|
||||
return
|
||||
async_dispatcher_send(
|
||||
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
||||
)
|
||||
|
||||
async def async_removed_from_registry(self) -> None:
|
||||
"""Clear retained discovery topic in broker."""
|
||||
|
@ -723,11 +730,20 @@ class MqttEntity(
|
|||
self.hass, self, self._config, self._entity_id_format
|
||||
)
|
||||
|
||||
@final
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe mqtt events."""
|
||||
"""Subscribe to MQTT events."""
|
||||
await super().async_added_to_hass()
|
||||
self._prepare_subscribe_topics()
|
||||
await self._subscribe_topics()
|
||||
await self.mqtt_async_added_to_hass()
|
||||
self.async_send_discovery_done()
|
||||
|
||||
async def mqtt_async_added_to_hass(self):
|
||||
"""Call before the discovery message is acknowledged.
|
||||
|
||||
To be extended by subclasses.
|
||||
"""
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
"""Handle updated discovery message."""
|
||||
|
|
|
@ -110,6 +110,7 @@ class MqttScene(
|
|||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await super().async_added_to_hass()
|
||||
self.async_send_discovery_done()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
"""Handle updated discovery message."""
|
||||
|
|
|
@ -163,9 +163,8 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
|||
|
||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
async def mqtt_async_added_to_hass(self) -> None:
|
||||
"""Restore state for entities with expire_after set."""
|
||||
await super().async_added_to_hass()
|
||||
if (
|
||||
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
||||
and expire_after > 0
|
||||
|
@ -174,7 +173,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
|||
and (last_sensor_data := await self.async_get_last_sensor_data())
|
||||
is not None
|
||||
# We might have set up a trigger already after subscribing from
|
||||
# super().async_added_to_hass(), then we should not restore state
|
||||
# MqttEntity.async_added_to_hass(), then we should not restore state
|
||||
and not self._expiration_trigger
|
||||
):
|
||||
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
||||
|
|
|
@ -34,12 +34,32 @@ def mock_try_connection():
|
|||
def mock_try_connection_success():
|
||||
"""Mock the try connection method with success."""
|
||||
|
||||
_mid = 1
|
||||
|
||||
def get_mid():
|
||||
nonlocal _mid
|
||||
_mid += 1
|
||||
return _mid
|
||||
|
||||
def loop_start():
|
||||
"""Simulate connect on loop start."""
|
||||
mock_client().on_connect(mock_client, None, None, 0)
|
||||
|
||||
def _subscribe(topic, qos=0):
|
||||
mid = get_mid()
|
||||
mock_client().on_subscribe(mock_client, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
def _unsubscribe(topic):
|
||||
mid = get_mid()
|
||||
mock_client().on_unsubscribe(mock_client, 0, mid)
|
||||
return (0, mid)
|
||||
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().loop_start = loop_start
|
||||
mock_client().subscribe = _subscribe
|
||||
mock_client().unsubscribe = _unsubscribe
|
||||
|
||||
yield mock_client()
|
||||
|
||||
|
||||
|
|
|
@ -550,6 +550,58 @@ async def test_rapid_rediscover_unique(hass, mqtt_mock, caplog):
|
|||
assert events[3].data["old_state"] is None
|
||||
|
||||
|
||||
async def test_rapid_reconfigure(hass, mqtt_mock, caplog):
|
||||
"""Test immediate reconfigure of added component."""
|
||||
|
||||
events = []
|
||||
|
||||
@ha.callback
|
||||
def callback(event):
|
||||
"""Verify event got called."""
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen(EVENT_STATE_CHANGED, callback)
|
||||
|
||||
# Discovery immediately followed by reconfig
|
||||
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
|
||||
async_fire_mqtt_message(
|
||||
hass,
|
||||
"homeassistant/binary_sensor/bla/config",
|
||||
'{ "name": "Beer", "state_topic": "test-topic1" }',
|
||||
)
|
||||
async_fire_mqtt_message(
|
||||
hass,
|
||||
"homeassistant/binary_sensor/bla/config",
|
||||
'{ "name": "Milk", "state_topic": "test-topic2" }',
|
||||
)
|
||||
async_fire_mqtt_message(
|
||||
hass,
|
||||
"homeassistant/binary_sensor/bla/config",
|
||||
'{ "name": "Wine", "state_topic": "test-topic3" }',
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
|
||||
state = hass.states.get("binary_sensor.beer")
|
||||
assert state is not None
|
||||
|
||||
assert len(events) == 3
|
||||
# Add the entity
|
||||
assert events[0].data["entity_id"] == "binary_sensor.beer"
|
||||
assert events[0].data["old_state"] is None
|
||||
assert events[0].data["new_state"].attributes["friendly_name"] == "Beer"
|
||||
# Update the entity
|
||||
assert events[1].data["entity_id"] == "binary_sensor.beer"
|
||||
assert events[1].data["new_state"] is not None
|
||||
assert events[1].data["old_state"] is not None
|
||||
assert events[1].data["new_state"].attributes["friendly_name"] == "Milk"
|
||||
# Update the entity
|
||||
assert events[2].data["entity_id"] == "binary_sensor.beer"
|
||||
assert events[2].data["new_state"] is not None
|
||||
assert events[2].data["old_state"] is not None
|
||||
assert events[2].data["new_state"].attributes["friendly_name"] == "Wine"
|
||||
|
||||
|
||||
async def test_duplicate_removal(hass, mqtt_mock, caplog):
|
||||
"""Test for a non duplicate component."""
|
||||
async_fire_mqtt_message(
|
||||
|
|
|
@ -618,6 +618,8 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
|
|||
)
|
||||
mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics
|
||||
mqtt_component_mock._mqttc = mqtt_client_mock
|
||||
# connected set to True to get a more realistics behavior when subscribing
|
||||
hass.data["mqtt"].connected = True
|
||||
|
||||
hass.data["mqtt"] = mqtt_component_mock
|
||||
component = hass.data["mqtt"]
|
||||
|
|
Loading…
Add table
Reference in a new issue