Avoid redundant calls to async_ha_write_state
in MQTT (binary) sensor (#100438)
* Only call `async_ha_write_state` on changes. * Make helper class * Use UndefinedType * Remove del * Integrate monitor into MqttEntity * Track extra state attributes and availability * Add `__slots__` * Add monitor to MqttAttributes and MqttAvailability * Write out loop * Add test * Make common test and parameterize * Add test for last_reset attribute * MqttMonitorEntity base class * Rename attr and update docstr `track` method. * correction doct * Implement as a decorator * Move tracking functions into decorator * Rename decorator * Follow up comment
This commit is contained in:
parent
11c4c37cf9
commit
aed3ba3acd
6 changed files with 147 additions and 11 deletions
|
@ -43,9 +43,9 @@ from .mixins import (
|
|||
MqttAvailability,
|
||||
MqttEntity,
|
||||
async_setup_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .models import MqttValueTemplate, ReceiveMessage
|
||||
from .util import get_mqtt_data
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -191,6 +191,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle a new received MQTT state message."""
|
||||
# auto-expire enabled?
|
||||
|
@ -257,8 +258,6 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
|||
self.hass, off_delay, off_delay_listener
|
||||
)
|
||||
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast, final
|
||||
|
||||
|
@ -101,6 +101,7 @@ from .discovery import (
|
|||
set_discovery_hash,
|
||||
)
|
||||
from .models import (
|
||||
MessageCallbackType,
|
||||
MqttValueTemplate,
|
||||
PublishPayloadType,
|
||||
ReceiveMessage,
|
||||
|
@ -346,6 +347,41 @@ def init_entity_id_from_config(
|
|||
)
|
||||
|
||||
|
||||
def write_state_on_attr_change(
|
||||
entity: Entity, attributes: set[str]
|
||||
) -> Callable[[MessageCallbackType], MessageCallbackType]:
|
||||
"""Wrap an MQTT message callback to track state attribute changes."""
|
||||
|
||||
def _attrs_have_changed(tracked_attrs: dict[str, Any]) -> bool:
|
||||
"""Return True if attributes on entity changed or if update is forced."""
|
||||
if not (write_state := (getattr(entity, "_attr_force_update", False))):
|
||||
for attribute, last_value in tracked_attrs.items():
|
||||
if getattr(entity, attribute, UNDEFINED) != last_value:
|
||||
write_state = True
|
||||
break
|
||||
|
||||
return write_state
|
||||
|
||||
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Track attributes for write state requests."""
|
||||
tracked_attrs: dict[str, Any] = {
|
||||
attribute: getattr(entity, attribute, UNDEFINED)
|
||||
for attribute in attributes
|
||||
}
|
||||
msg_callback(msg)
|
||||
if not _attrs_have_changed(tracked_attrs):
|
||||
return
|
||||
|
||||
mqtt_data = get_mqtt_data(entity.hass)
|
||||
mqtt_data.state_write_requests.write_state_request(entity)
|
||||
|
||||
return wrapper
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class MqttAttributes(Entity):
|
||||
"""Mixin used for platforms that support JSON attributes."""
|
||||
|
||||
|
@ -379,6 +415,7 @@ class MqttAttributes(Entity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
|
||||
def attributes_message_received(msg: ReceiveMessage) -> None:
|
||||
try:
|
||||
payload = attr_tpl(msg.payload)
|
||||
|
@ -391,9 +428,6 @@ class MqttAttributes(Entity):
|
|||
and k not in self._attributes_extra_blocked
|
||||
}
|
||||
self._attr_extra_state_attributes = filtered_dict
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(
|
||||
self
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning("JSON result was not a dictionary")
|
||||
except ValueError:
|
||||
|
@ -488,6 +522,7 @@ class MqttAvailability(Entity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"available"})
|
||||
def availability_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle a new received MQTT availability message."""
|
||||
topic = msg.topic
|
||||
|
@ -500,8 +535,6 @@ class MqttAvailability(Entity):
|
|||
self._available[topic] = False
|
||||
self._available_latest = False
|
||||
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||
|
||||
self._available = {
|
||||
topic: (self._available[topic] if topic in self._available else False)
|
||||
for topic in self._avail_topics
|
||||
|
|
|
@ -45,6 +45,7 @@ from .mixins import (
|
|||
MqttAvailability,
|
||||
MqttEntity,
|
||||
async_setup_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .models import (
|
||||
MqttValueTemplate,
|
||||
|
@ -52,7 +53,6 @@ from .models import (
|
|||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .util import get_mqtt_data
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -287,13 +287,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
|||
)
|
||||
|
||||
@callback
|
||||
@write_state_on_attr_change(self, {"_attr_native_value", "_attr_last_reset"})
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
_update_state(msg)
|
||||
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
|
||||
_update_last_reset(msg)
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||
|
||||
topics["state_topic"] = {
|
||||
"topic": self._config[CONF_STATE_TOPIC],
|
||||
|
|
|
@ -47,6 +47,7 @@ from .test_common import (
|
|||
help_test_reloadable,
|
||||
help_test_setting_attribute_via_mqtt_json_message,
|
||||
help_test_setting_attribute_with_template,
|
||||
help_test_skipped_async_ha_write_state,
|
||||
help_test_unique_id,
|
||||
help_test_unload_config_entry_with_platform,
|
||||
help_test_update_with_json_attrs_bad_json,
|
||||
|
@ -1248,3 +1249,38 @@ async def test_entity_name(
|
|||
await help_test_entity_name(
|
||||
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hass_config",
|
||||
[
|
||||
help_custom_config(
|
||||
binary_sensor.DOMAIN,
|
||||
DEFAULT_CONFIG,
|
||||
(
|
||||
{
|
||||
"availability_topic": "availability-topic",
|
||||
"json_attributes_topic": "json-attributes-topic",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("topic", "payload1", "payload2"),
|
||||
[
|
||||
("test-topic", "ON", "OFF"),
|
||||
("availability-topic", "online", "offline"),
|
||||
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
|
||||
],
|
||||
)
|
||||
async def test_skipped_async_ha_write_state(
|
||||
hass: HomeAssistant,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
topic: str,
|
||||
payload1: str,
|
||||
payload2: str,
|
||||
) -> None:
|
||||
"""Test a write state command is only called when there is change."""
|
||||
await mqtt_mock_entry()
|
||||
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
|
||||
|
|
|
@ -1925,3 +1925,28 @@ async def help_test_discovery_setup(
|
|||
await hass.async_block_till_done()
|
||||
state = hass.states.get(f"{domain}.{name}")
|
||||
assert state and state.state is not None
|
||||
|
||||
|
||||
async def help_test_skipped_async_ha_write_state(
|
||||
hass: HomeAssistant, topic: str, payload1: str, payload2: str
|
||||
) -> None:
|
||||
"""Test entity.async_ha_write_state is only called on changes."""
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state"
|
||||
) as mock_async_ha_write_state:
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 0
|
||||
async_fire_mqtt_message(hass, topic, payload1)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 1
|
||||
|
||||
async_fire_mqtt_message(hass, topic, payload1)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 1
|
||||
|
||||
async_fire_mqtt_message(hass, topic, payload2)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 2
|
||||
|
||||
async_fire_mqtt_message(hass, topic, payload2)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 2
|
||||
|
|
|
@ -60,6 +60,7 @@ from .test_common import (
|
|||
help_test_setting_attribute_via_mqtt_json_message,
|
||||
help_test_setting_attribute_with_template,
|
||||
help_test_setting_blocked_attribute_via_mqtt_json_message,
|
||||
help_test_skipped_async_ha_write_state,
|
||||
help_test_unique_id,
|
||||
help_test_unload_config_entry_with_platform,
|
||||
help_test_update_with_json_attrs_bad_json,
|
||||
|
@ -1437,3 +1438,45 @@ async def test_entity_name(
|
|||
await help_test_entity_name(
|
||||
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hass_config",
|
||||
[
|
||||
help_custom_config(
|
||||
sensor.DOMAIN,
|
||||
DEFAULT_CONFIG,
|
||||
(
|
||||
{
|
||||
"availability_topic": "availability-topic",
|
||||
"json_attributes_topic": "json-attributes-topic",
|
||||
"value_template": "{{ value_json.state }}",
|
||||
"last_reset_value_template": "{{ value_json.last_reset }}",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("topic", "payload1", "payload2"),
|
||||
[
|
||||
("test-topic", '{"state":"val1"}', '{"state":"val2"}'),
|
||||
(
|
||||
"test-topic",
|
||||
'{"last_reset":"2023-09-15 15:11:03"}',
|
||||
'{"last_reset":"2023-09-16 15:11:02"}',
|
||||
),
|
||||
("availability-topic", "online", "offline"),
|
||||
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
|
||||
],
|
||||
)
|
||||
async def test_skipped_async_ha_write_state(
|
||||
hass: HomeAssistant,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
topic: str,
|
||||
payload1: str,
|
||||
payload2: str,
|
||||
) -> None:
|
||||
"""Test a write state command is only called when there is change."""
|
||||
await mqtt_mock_entry()
|
||||
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue