Avoid redundant calls to async_ha_write_state in MQTT (binary) sensor (#100438)

* Only call `async_ha_write_state` on changes.

* Make helper class

* Use UndefinedType

* Remove del

* Integrate monitor into MqttEntity

* Track extra state attributes and availability

* Add `__slots__`

* Add monitor to MqttAttributes and MqttAvailability

* Write out loop

* Add test

* Make common test and parameterize

* Add test for last_reset attribute

* MqttMonitorEntity base class

* Rename attr and update docstr `track` method.

* correction doct

* Implement as a decorator

* Move tracking functions into decorator

* Rename decorator

* Follow up comment
This commit is contained in:
Jan Bouwhuis 2023-09-21 13:33:26 +02:00 committed by GitHub
parent 11c4c37cf9
commit aed3ba3acd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 11 deletions

View file

@ -43,9 +43,9 @@ from .mixins import (
MqttAvailability, MqttAvailability,
MqttEntity, MqttEntity,
async_setup_entry_helper, async_setup_entry_helper,
write_state_on_attr_change,
) )
from .models import MqttValueTemplate, ReceiveMessage from .models import MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -191,6 +191,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_message_received(msg: ReceiveMessage) -> None: def state_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message.""" """Handle a new received MQTT state message."""
# auto-expire enabled? # auto-expire enabled?
@ -257,8 +258,6 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
self.hass, off_delay, off_delay_listener self.hass, off_delay, off_delay_listener
) )
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from functools import partial from functools import partial, wraps
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, cast, final from typing import TYPE_CHECKING, Any, Protocol, cast, final
@ -101,6 +101,7 @@ from .discovery import (
set_discovery_hash, set_discovery_hash,
) )
from .models import ( from .models import (
MessageCallbackType,
MqttValueTemplate, MqttValueTemplate,
PublishPayloadType, PublishPayloadType,
ReceiveMessage, ReceiveMessage,
@ -346,6 +347,41 @@ def init_entity_id_from_config(
) )
def write_state_on_attr_change(
entity: Entity, attributes: set[str]
) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to track state attribute changes."""
def _attrs_have_changed(tracked_attrs: dict[str, Any]) -> bool:
"""Return True if attributes on entity changed or if update is forced."""
if not (write_state := (getattr(entity, "_attr_force_update", False))):
for attribute, last_value in tracked_attrs.items():
if getattr(entity, attribute, UNDEFINED) != last_value:
write_state = True
break
return write_state
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Track attributes for write state requests."""
tracked_attrs: dict[str, Any] = {
attribute: getattr(entity, attribute, UNDEFINED)
for attribute in attributes
}
msg_callback(msg)
if not _attrs_have_changed(tracked_attrs):
return
mqtt_data = get_mqtt_data(entity.hass)
mqtt_data.state_write_requests.write_state_request(entity)
return wrapper
return _decorator
class MqttAttributes(Entity): class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes.""" """Mixin used for platforms that support JSON attributes."""
@ -379,6 +415,7 @@ class MqttAttributes(Entity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
def attributes_message_received(msg: ReceiveMessage) -> None: def attributes_message_received(msg: ReceiveMessage) -> None:
try: try:
payload = attr_tpl(msg.payload) payload = attr_tpl(msg.payload)
@ -391,9 +428,6 @@ class MqttAttributes(Entity):
and k not in self._attributes_extra_blocked and k not in self._attributes_extra_blocked
} }
self._attr_extra_state_attributes = filtered_dict self._attr_extra_state_attributes = filtered_dict
get_mqtt_data(self.hass).state_write_requests.write_state_request(
self
)
else: else:
_LOGGER.warning("JSON result was not a dictionary") _LOGGER.warning("JSON result was not a dictionary")
except ValueError: except ValueError:
@ -488,6 +522,7 @@ class MqttAvailability(Entity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"available"})
def availability_message_received(msg: ReceiveMessage) -> None: def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message.""" """Handle a new received MQTT availability message."""
topic = msg.topic topic = msg.topic
@ -500,8 +535,6 @@ class MqttAvailability(Entity):
self._available[topic] = False self._available[topic] = False
self._available_latest = False self._available_latest = False
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self._available = { self._available = {
topic: (self._available[topic] if topic in self._available else False) topic: (self._available[topic] if topic in self._available else False)
for topic in self._avail_topics for topic in self._avail_topics

View file

@ -45,6 +45,7 @@ from .mixins import (
MqttAvailability, MqttAvailability,
MqttEntity, MqttEntity,
async_setup_entry_helper, async_setup_entry_helper,
write_state_on_attr_change,
) )
from .models import ( from .models import (
MqttValueTemplate, MqttValueTemplate,
@ -52,7 +53,6 @@ from .models import (
ReceiveMessage, ReceiveMessage,
ReceivePayloadType, ReceivePayloadType,
) )
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -287,13 +287,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
) )
@callback @callback
@write_state_on_attr_change(self, {"_attr_native_value", "_attr_last_reset"})
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None: def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
_update_state(msg) _update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config: if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
_update_last_reset(msg) _update_last_reset(msg)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],

View file

@ -47,6 +47,7 @@ from .test_common import (
help_test_reloadable, help_test_reloadable,
help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template, help_test_setting_attribute_with_template,
help_test_skipped_async_ha_write_state,
help_test_unique_id, help_test_unique_id,
help_test_unload_config_entry_with_platform, help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json, help_test_update_with_json_attrs_bad_json,
@ -1248,3 +1249,38 @@ async def test_entity_name(
await help_test_entity_name( await help_test_entity_name(
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
) )
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
binary_sensor.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("test-topic", "ON", "OFF"),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)

View file

@ -1925,3 +1925,28 @@ async def help_test_discovery_setup(
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get(f"{domain}.{name}") state = hass.states.get(f"{domain}.{name}")
assert state and state.state is not None assert state and state.state is not None
async def help_test_skipped_async_ha_write_state(
hass: HomeAssistant, topic: str, payload1: str, payload2: str
) -> None:
"""Test entity.async_ha_write_state is only called on changes."""
with patch(
"homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state"
) as mock_async_ha_write_state:
assert len(mock_async_ha_write_state.mock_calls) == 0
async_fire_mqtt_message(hass, topic, payload1)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 1
async_fire_mqtt_message(hass, topic, payload1)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 1
async_fire_mqtt_message(hass, topic, payload2)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 2
async_fire_mqtt_message(hass, topic, payload2)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 2

View file

@ -60,6 +60,7 @@ from .test_common import (
help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template, help_test_setting_attribute_with_template,
help_test_setting_blocked_attribute_via_mqtt_json_message, help_test_setting_blocked_attribute_via_mqtt_json_message,
help_test_skipped_async_ha_write_state,
help_test_unique_id, help_test_unique_id,
help_test_unload_config_entry_with_platform, help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json, help_test_update_with_json_attrs_bad_json,
@ -1437,3 +1438,45 @@ async def test_entity_name(
await help_test_entity_name( await help_test_entity_name(
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
) )
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
sensor.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
"value_template": "{{ value_json.state }}",
"last_reset_value_template": "{{ value_json.last_reset }}",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("test-topic", '{"state":"val1"}', '{"state":"val2"}'),
(
"test-topic",
'{"last_reset":"2023-09-15 15:11:03"}',
'{"last_reset":"2023-09-16 15:11:02"}',
),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)