Fix race when handling rapid succession of MQTT discovery messages (#68785)
Co-authored-by: jbouwh <jan@jbsoft.nl>
This commit is contained in:
parent
3d378449e8
commit
7e8d52e5a3
7 changed files with 103 additions and 14 deletions
|
@ -118,16 +118,15 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
||||||
|
|
||||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def mqtt_async_added_to_hass(self) -> None:
|
||||||
"""Restore state for entities with expire_after set."""
|
"""Restore state for entities with expire_after set."""
|
||||||
await super().async_added_to_hass()
|
|
||||||
if (
|
if (
|
||||||
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
||||||
and expire_after > 0
|
and expire_after > 0
|
||||||
and (last_state := await self.async_get_last_state()) is not None
|
and (last_state := await self.async_get_last_state()) is not None
|
||||||
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
|
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
|
||||||
# We might have set up a trigger already after subscribing from
|
# We might have set up a trigger already after subscribing from
|
||||||
# super().async_added_to_hass(), then we should not restore state
|
# MqttEntity.async_added_to_hass(), then we should not restore state
|
||||||
and not self._expiration_trigger
|
and not self._expiration_trigger
|
||||||
):
|
):
|
||||||
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
||||||
|
|
|
@ -5,7 +5,7 @@ from abc import abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol, final
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -572,9 +572,7 @@ class MqttDiscoveryUpdate(Entity):
|
||||||
else:
|
else:
|
||||||
# Non-empty, unchanged payload: Ignore to avoid changing states
|
# Non-empty, unchanged payload: Ignore to avoid changing states
|
||||||
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
|
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
|
||||||
async_dispatcher_send(
|
self.async_send_discovery_done()
|
||||||
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
|
||||||
)
|
|
||||||
|
|
||||||
if discovery_hash:
|
if discovery_hash:
|
||||||
debug_info.add_entity_discovery_data(
|
debug_info.add_entity_discovery_data(
|
||||||
|
@ -587,9 +585,18 @@ class MqttDiscoveryUpdate(Entity):
|
||||||
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
|
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
|
||||||
discovery_callback,
|
discovery_callback,
|
||||||
)
|
)
|
||||||
async_dispatcher_send(
|
|
||||||
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
@callback
|
||||||
)
|
def async_send_discovery_done(self) -> None:
|
||||||
|
"""Acknowledge a discovery message has been handled."""
|
||||||
|
discovery_hash = (
|
||||||
|
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
|
||||||
|
)
|
||||||
|
if not discovery_hash:
|
||||||
|
return
|
||||||
|
async_dispatcher_send(
|
||||||
|
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
|
||||||
|
)
|
||||||
|
|
||||||
async def async_removed_from_registry(self) -> None:
|
async def async_removed_from_registry(self) -> None:
|
||||||
"""Clear retained discovery topic in broker."""
|
"""Clear retained discovery topic in broker."""
|
||||||
|
@ -723,11 +730,20 @@ class MqttEntity(
|
||||||
self.hass, self, self._config, self._entity_id_format
|
self.hass, self, self._config, self._entity_id_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@final
|
||||||
async def async_added_to_hass(self):
|
async def async_added_to_hass(self):
|
||||||
"""Subscribe mqtt events."""
|
"""Subscribe to MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self._prepare_subscribe_topics()
|
self._prepare_subscribe_topics()
|
||||||
await self._subscribe_topics()
|
await self._subscribe_topics()
|
||||||
|
await self.mqtt_async_added_to_hass()
|
||||||
|
self.async_send_discovery_done()
|
||||||
|
|
||||||
|
async def mqtt_async_added_to_hass(self):
|
||||||
|
"""Call before the discovery message is acknowledged.
|
||||||
|
|
||||||
|
To be extended by subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
async def discovery_update(self, discovery_payload):
|
async def discovery_update(self, discovery_payload):
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
|
|
|
@ -110,6 +110,7 @@ class MqttScene(
|
||||||
async def async_added_to_hass(self):
|
async def async_added_to_hass(self):
|
||||||
"""Subscribe to MQTT events."""
|
"""Subscribe to MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
self.async_send_discovery_done()
|
||||||
|
|
||||||
async def discovery_update(self, discovery_payload):
|
async def discovery_update(self, discovery_payload):
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
|
|
|
@ -163,9 +163,8 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
||||||
|
|
||||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def mqtt_async_added_to_hass(self) -> None:
|
||||||
"""Restore state for entities with expire_after set."""
|
"""Restore state for entities with expire_after set."""
|
||||||
await super().async_added_to_hass()
|
|
||||||
if (
|
if (
|
||||||
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
|
||||||
and expire_after > 0
|
and expire_after > 0
|
||||||
|
@ -174,7 +173,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
||||||
and (last_sensor_data := await self.async_get_last_sensor_data())
|
and (last_sensor_data := await self.async_get_last_sensor_data())
|
||||||
is not None
|
is not None
|
||||||
# We might have set up a trigger already after subscribing from
|
# We might have set up a trigger already after subscribing from
|
||||||
# super().async_added_to_hass(), then we should not restore state
|
# MqttEntity.async_added_to_hass(), then we should not restore state
|
||||||
and not self._expiration_trigger
|
and not self._expiration_trigger
|
||||||
):
|
):
|
||||||
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)
|
||||||
|
|
|
@ -34,12 +34,32 @@ def mock_try_connection():
|
||||||
def mock_try_connection_success():
|
def mock_try_connection_success():
|
||||||
"""Mock the try connection method with success."""
|
"""Mock the try connection method with success."""
|
||||||
|
|
||||||
|
_mid = 1
|
||||||
|
|
||||||
|
def get_mid():
|
||||||
|
nonlocal _mid
|
||||||
|
_mid += 1
|
||||||
|
return _mid
|
||||||
|
|
||||||
def loop_start():
|
def loop_start():
|
||||||
"""Simulate connect on loop start."""
|
"""Simulate connect on loop start."""
|
||||||
mock_client().on_connect(mock_client, None, None, 0)
|
mock_client().on_connect(mock_client, None, None, 0)
|
||||||
|
|
||||||
|
def _subscribe(topic, qos=0):
|
||||||
|
mid = get_mid()
|
||||||
|
mock_client().on_subscribe(mock_client, 0, mid)
|
||||||
|
return (0, mid)
|
||||||
|
|
||||||
|
def _unsubscribe(topic):
|
||||||
|
mid = get_mid()
|
||||||
|
mock_client().on_unsubscribe(mock_client, 0, mid)
|
||||||
|
return (0, mid)
|
||||||
|
|
||||||
with patch("paho.mqtt.client.Client") as mock_client:
|
with patch("paho.mqtt.client.Client") as mock_client:
|
||||||
mock_client().loop_start = loop_start
|
mock_client().loop_start = loop_start
|
||||||
|
mock_client().subscribe = _subscribe
|
||||||
|
mock_client().unsubscribe = _unsubscribe
|
||||||
|
|
||||||
yield mock_client()
|
yield mock_client()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -550,6 +550,58 @@ async def test_rapid_rediscover_unique(hass, mqtt_mock, caplog):
|
||||||
assert events[3].data["old_state"] is None
|
assert events[3].data["old_state"] is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rapid_reconfigure(hass, mqtt_mock, caplog):
|
||||||
|
"""Test immediate reconfigure of added component."""
|
||||||
|
|
||||||
|
events = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def callback(event):
|
||||||
|
"""Verify event got called."""
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
hass.bus.async_listen(EVENT_STATE_CHANGED, callback)
|
||||||
|
|
||||||
|
# Discovery immediately followed by reconfig
|
||||||
|
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
|
||||||
|
async_fire_mqtt_message(
|
||||||
|
hass,
|
||||||
|
"homeassistant/binary_sensor/bla/config",
|
||||||
|
'{ "name": "Beer", "state_topic": "test-topic1" }',
|
||||||
|
)
|
||||||
|
async_fire_mqtt_message(
|
||||||
|
hass,
|
||||||
|
"homeassistant/binary_sensor/bla/config",
|
||||||
|
'{ "name": "Milk", "state_topic": "test-topic2" }',
|
||||||
|
)
|
||||||
|
async_fire_mqtt_message(
|
||||||
|
hass,
|
||||||
|
"homeassistant/binary_sensor/bla/config",
|
||||||
|
'{ "name": "Wine", "state_topic": "test-topic3" }',
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
|
||||||
|
state = hass.states.get("binary_sensor.beer")
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
assert len(events) == 3
|
||||||
|
# Add the entity
|
||||||
|
assert events[0].data["entity_id"] == "binary_sensor.beer"
|
||||||
|
assert events[0].data["old_state"] is None
|
||||||
|
assert events[0].data["new_state"].attributes["friendly_name"] == "Beer"
|
||||||
|
# Update the entity
|
||||||
|
assert events[1].data["entity_id"] == "binary_sensor.beer"
|
||||||
|
assert events[1].data["new_state"] is not None
|
||||||
|
assert events[1].data["old_state"] is not None
|
||||||
|
assert events[1].data["new_state"].attributes["friendly_name"] == "Milk"
|
||||||
|
# Update the entity
|
||||||
|
assert events[2].data["entity_id"] == "binary_sensor.beer"
|
||||||
|
assert events[2].data["new_state"] is not None
|
||||||
|
assert events[2].data["old_state"] is not None
|
||||||
|
assert events[2].data["new_state"].attributes["friendly_name"] == "Wine"
|
||||||
|
|
||||||
|
|
||||||
async def test_duplicate_removal(hass, mqtt_mock, caplog):
|
async def test_duplicate_removal(hass, mqtt_mock, caplog):
|
||||||
"""Test for a non duplicate component."""
|
"""Test for a non duplicate component."""
|
||||||
async_fire_mqtt_message(
|
async_fire_mqtt_message(
|
||||||
|
|
|
@ -618,6 +618,8 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
|
||||||
)
|
)
|
||||||
mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics
|
mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics
|
||||||
mqtt_component_mock._mqttc = mqtt_client_mock
|
mqtt_component_mock._mqttc = mqtt_client_mock
|
||||||
|
# connected set to True to get a more realistics behavior when subscribing
|
||||||
|
hass.data["mqtt"].connected = True
|
||||||
|
|
||||||
hass.data["mqtt"] = mqtt_component_mock
|
hass.data["mqtt"] = mqtt_component_mock
|
||||||
component = hass.data["mqtt"]
|
component = hass.data["mqtt"]
|
||||||
|
|
Loading…
Add table
Reference in a new issue