Fix race when handling rapid succession of MQTT discovery messages (#68785)

Co-authored-by: jbouwh <jan@jbsoft.nl>
This commit is contained in:
Erik Montnemery 2022-03-30 05:26:11 +02:00 committed by GitHub
parent 3d378449e8
commit 7e8d52e5a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 14 deletions

View file

@ -118,16 +118,15 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def async_added_to_hass(self) -> None:
async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set."""
await super().async_added_to_hass()
if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
and expire_after > 0
and (last_state := await self.async_get_last_state()) is not None
and last_state.state not in [STATE_UNKNOWN, STATE_UNAVAILABLE]
# We might have set up a trigger already after subscribing from
# super().async_added_to_hass(), then we should not restore state
# MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger
):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)

View file

@ -5,7 +5,7 @@ from abc import abstractmethod
from collections.abc import Callable
import json
import logging
from typing import Any, Protocol
from typing import Any, Protocol, final
import voluptuous as vol
@ -572,9 +572,7 @@ class MqttDiscoveryUpdate(Entity):
else:
# Non-empty, unchanged payload: Ignore to avoid changing states
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
self.async_send_discovery_done()
if discovery_hash:
debug_info.add_entity_discovery_data(
@ -587,9 +585,18 @@ class MqttDiscoveryUpdate(Entity):
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback,
)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
@callback
def async_send_discovery_done(self) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
if not discovery_hash:
return
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker."""
@ -723,11 +730,20 @@ class MqttEntity(
self.hass, self, self._config, self._entity_id_format
)
@final
async def async_added_to_hass(self):
"""Subscribe mqtt events."""
"""Subscribe to MQTT events."""
await super().async_added_to_hass()
self._prepare_subscribe_topics()
await self._subscribe_topics()
await self.mqtt_async_added_to_hass()
self.async_send_discovery_done()
async def mqtt_async_added_to_hass(self):
"""Call before the discovery message is acknowledged.
To be extended by subclasses.
"""
async def discovery_update(self, discovery_payload):
"""Handle updated discovery message."""

View file

@ -110,6 +110,7 @@ class MqttScene(
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await super().async_added_to_hass()
self.async_send_discovery_done()
async def discovery_update(self, discovery_payload):
"""Handle updated discovery message."""

View file

@ -163,9 +163,8 @@ class MqttSensor(MqttEntity, RestoreSensor):
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
async def async_added_to_hass(self) -> None:
async def mqtt_async_added_to_hass(self) -> None:
"""Restore state for entities with expire_after set."""
await super().async_added_to_hass()
if (
(expire_after := self._config.get(CONF_EXPIRE_AFTER)) is not None
and expire_after > 0
@ -174,7 +173,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
and (last_sensor_data := await self.async_get_last_sensor_data())
is not None
# We might have set up a trigger already after subscribing from
# super().async_added_to_hass(), then we should not restore state
# MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger
):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after)

View file

@ -34,12 +34,32 @@ def mock_try_connection():
def mock_try_connection_success():
"""Mock the try connection method with success."""
_mid = 1
def get_mid():
nonlocal _mid
_mid += 1
return _mid
def loop_start():
"""Simulate connect on loop start."""
mock_client().on_connect(mock_client, None, None, 0)
def _subscribe(topic, qos=0):
mid = get_mid()
mock_client().on_subscribe(mock_client, 0, mid)
return (0, mid)
def _unsubscribe(topic):
mid = get_mid()
mock_client().on_unsubscribe(mock_client, 0, mid)
return (0, mid)
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().loop_start = loop_start
mock_client().subscribe = _subscribe
mock_client().unsubscribe = _unsubscribe
yield mock_client()

View file

@ -550,6 +550,58 @@ async def test_rapid_rediscover_unique(hass, mqtt_mock, caplog):
assert events[3].data["old_state"] is None
async def test_rapid_reconfigure(hass, mqtt_mock, caplog):
"""Test immediate reconfigure of added component."""
events = []
@ha.callback
def callback(event):
"""Verify event got called."""
events.append(event)
hass.bus.async_listen(EVENT_STATE_CHANGED, callback)
# Discovery immediately followed by reconfig
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Beer", "state_topic": "test-topic1" }',
)
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Milk", "state_topic": "test-topic2" }',
)
async_fire_mqtt_message(
hass,
"homeassistant/binary_sensor/bla/config",
'{ "name": "Wine", "state_topic": "test-topic3" }',
)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids("binary_sensor")) == 1
state = hass.states.get("binary_sensor.beer")
assert state is not None
assert len(events) == 3
# Add the entity
assert events[0].data["entity_id"] == "binary_sensor.beer"
assert events[0].data["old_state"] is None
assert events[0].data["new_state"].attributes["friendly_name"] == "Beer"
# Update the entity
assert events[1].data["entity_id"] == "binary_sensor.beer"
assert events[1].data["new_state"] is not None
assert events[1].data["old_state"] is not None
assert events[1].data["new_state"].attributes["friendly_name"] == "Milk"
# Update the entity
assert events[2].data["entity_id"] == "binary_sensor.beer"
assert events[2].data["new_state"] is not None
assert events[2].data["old_state"] is not None
assert events[2].data["new_state"].attributes["friendly_name"] == "Wine"
async def test_duplicate_removal(hass, mqtt_mock, caplog):
"""Test for a non duplicate component."""
async_fire_mqtt_message(

View file

@ -618,6 +618,8 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
)
mqtt_component_mock.conf = hass.data["mqtt"].conf # For diagnostics
mqtt_component_mock._mqttc = mqtt_client_mock
# connected set to True to get a more realistics behavior when subscribing
hass.data["mqtt"].connected = True
hass.data["mqtt"] = mqtt_component_mock
component = hass.data["mqtt"]