Avoid redundant calls to async_ha_write_state in MQTT (binary) sensor (#100438)

* Only call `async_ha_write_state` on changes.

* Make helper class

* Use UndefinedType

* Remove del

* Integrate monitor into MqttEntity

* Track extra state attributes and availability

* Add `__slots__`

* Add monitor to MqttAttributes and MqttAvailability

* Write out loop

* Add test

* Make common test and parameterize

* Add test for last_reset attribute

* MqttMonitorEntity base class

* Rename attr and update docstr `track` method.

* correction doct

* Implement as a decorator

* Move tracking functions into decorator

* Rename decorator

* Follow up comment
This commit is contained in:
Jan Bouwhuis 2023-09-21 13:33:26 +02:00 committed by GitHub
parent 11c4c37cf9
commit aed3ba3acd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 11 deletions

View file

@ -43,9 +43,9 @@ from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -191,6 +191,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message."""
# auto-expire enabled?
@ -257,8 +258,6 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
self.hass, off_delay, off_delay_listener
)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable, Coroutine
from functools import partial
from functools import partial, wraps
import logging
from typing import TYPE_CHECKING, Any, Protocol, cast, final
@ -101,6 +101,7 @@ from .discovery import (
set_discovery_hash,
)
from .models import (
MessageCallbackType,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
@ -346,6 +347,41 @@ def init_entity_id_from_config(
)
def write_state_on_attr_change(
entity: Entity, attributes: set[str]
) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to track state attribute changes."""
def _attrs_have_changed(tracked_attrs: dict[str, Any]) -> bool:
"""Return True if attributes on entity changed or if update is forced."""
if not (write_state := (getattr(entity, "_attr_force_update", False))):
for attribute, last_value in tracked_attrs.items():
if getattr(entity, attribute, UNDEFINED) != last_value:
write_state = True
break
return write_state
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Track attributes for write state requests."""
tracked_attrs: dict[str, Any] = {
attribute: getattr(entity, attribute, UNDEFINED)
for attribute in attributes
}
msg_callback(msg)
if not _attrs_have_changed(tracked_attrs):
return
mqtt_data = get_mqtt_data(entity.hass)
mqtt_data.state_write_requests.write_state_request(entity)
return wrapper
return _decorator
class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes."""
@ -379,6 +415,7 @@ class MqttAttributes(Entity):
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
def attributes_message_received(msg: ReceiveMessage) -> None:
try:
payload = attr_tpl(msg.payload)
@ -391,9 +428,6 @@ class MqttAttributes(Entity):
and k not in self._attributes_extra_blocked
}
self._attr_extra_state_attributes = filtered_dict
get_mqtt_data(self.hass).state_write_requests.write_state_request(
self
)
else:
_LOGGER.warning("JSON result was not a dictionary")
except ValueError:
@ -488,6 +522,7 @@ class MqttAvailability(Entity):
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"available"})
def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
@ -500,8 +535,6 @@ class MqttAvailability(Entity):
self._available[topic] = False
self._available_latest = False
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self._available = {
topic: (self._available[topic] if topic in self._available else False)
for topic in self._avail_topics

View file

@ -45,6 +45,7 @@ from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entry_helper,
write_state_on_attr_change,
)
from .models import (
MqttValueTemplate,
@ -52,7 +53,6 @@ from .models import (
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -287,13 +287,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
)
@callback
@write_state_on_attr_change(self, {"_attr_native_value", "_attr_last_reset"})
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
_update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
_update_last_reset(msg)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC],

View file

@ -47,6 +47,7 @@ from .test_common import (
help_test_reloadable,
help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template,
help_test_skipped_async_ha_write_state,
help_test_unique_id,
help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json,
@ -1248,3 +1249,38 @@ async def test_entity_name(
await help_test_entity_name(
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
)
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
binary_sensor.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("test-topic", "ON", "OFF"),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)

View file

@ -1925,3 +1925,28 @@ async def help_test_discovery_setup(
await hass.async_block_till_done()
state = hass.states.get(f"{domain}.{name}")
assert state and state.state is not None
async def help_test_skipped_async_ha_write_state(
hass: HomeAssistant, topic: str, payload1: str, payload2: str
) -> None:
"""Test entity.async_ha_write_state is only called on changes."""
with patch(
"homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state"
) as mock_async_ha_write_state:
assert len(mock_async_ha_write_state.mock_calls) == 0
async_fire_mqtt_message(hass, topic, payload1)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 1
async_fire_mqtt_message(hass, topic, payload1)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 1
async_fire_mqtt_message(hass, topic, payload2)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 2
async_fire_mqtt_message(hass, topic, payload2)
await hass.async_block_till_done()
assert len(mock_async_ha_write_state.mock_calls) == 2

View file

@ -60,6 +60,7 @@ from .test_common import (
help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template,
help_test_setting_blocked_attribute_via_mqtt_json_message,
help_test_skipped_async_ha_write_state,
help_test_unique_id,
help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json,
@ -1437,3 +1438,45 @@ async def test_entity_name(
await help_test_entity_name(
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
)
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
sensor.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
"value_template": "{{ value_json.state }}",
"last_reset_value_template": "{{ value_json.last_reset }}",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("test-topic", '{"state":"val1"}', '{"state":"val2"}'),
(
"test-topic",
'{"last_reset":"2023-09-15 15:11:03"}',
'{"last_reset":"2023-09-16 15:11:02"}',
),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)