Avoid redundant calls to async_ha_write_state
in MQTT (binary) sensor (#100438)
* Only call `async_ha_write_state` on changes. * Make helper class * Use UndefinedType * Remove del * Integrate monitor into MqttEntity * Track extra state attributes and availability * Add `__slots__` * Add monitor to MqttAttributes and MqttAvailability * Write out loop * Add test * Make common test and parameterize * Add test for last_reset attribute * MqttMonitorEntity base class * Rename attr and update docstr `track` method. * correction doct * Implement as a decorator * Move tracking functions into decorator * Rename decorator * Follow up comment
This commit is contained in:
parent
11c4c37cf9
commit
aed3ba3acd
6 changed files with 147 additions and 11 deletions
|
@ -43,9 +43,9 @@ from .mixins import (
|
||||||
MqttAvailability,
|
MqttAvailability,
|
||||||
MqttEntity,
|
MqttEntity,
|
||||||
async_setup_entry_helper,
|
async_setup_entry_helper,
|
||||||
|
write_state_on_attr_change,
|
||||||
)
|
)
|
||||||
from .models import MqttValueTemplate, ReceiveMessage
|
from .models import MqttValueTemplate, ReceiveMessage
|
||||||
from .util import get_mqtt_data
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -191,6 +191,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
|
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||||
def state_message_received(msg: ReceiveMessage) -> None:
|
def state_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle a new received MQTT state message."""
|
"""Handle a new received MQTT state message."""
|
||||||
# auto-expire enabled?
|
# auto-expire enabled?
|
||||||
|
@ -257,8 +258,6 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
||||||
self.hass, off_delay, off_delay_listener
|
self.hass, off_delay, off_delay_listener
|
||||||
)
|
)
|
||||||
|
|
||||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, cast, final
|
from typing import TYPE_CHECKING, Any, Protocol, cast, final
|
||||||
|
|
||||||
|
@ -101,6 +101,7 @@ from .discovery import (
|
||||||
set_discovery_hash,
|
set_discovery_hash,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
|
MessageCallbackType,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
PublishPayloadType,
|
PublishPayloadType,
|
||||||
ReceiveMessage,
|
ReceiveMessage,
|
||||||
|
@ -346,6 +347,41 @@ def init_entity_id_from_config(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def write_state_on_attr_change(
|
||||||
|
entity: Entity, attributes: set[str]
|
||||||
|
) -> Callable[[MessageCallbackType], MessageCallbackType]:
|
||||||
|
"""Wrap an MQTT message callback to track state attribute changes."""
|
||||||
|
|
||||||
|
def _attrs_have_changed(tracked_attrs: dict[str, Any]) -> bool:
|
||||||
|
"""Return True if attributes on entity changed or if update is forced."""
|
||||||
|
if not (write_state := (getattr(entity, "_attr_force_update", False))):
|
||||||
|
for attribute, last_value in tracked_attrs.items():
|
||||||
|
if getattr(entity, attribute, UNDEFINED) != last_value:
|
||||||
|
write_state = True
|
||||||
|
break
|
||||||
|
|
||||||
|
return write_state
|
||||||
|
|
||||||
|
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
|
||||||
|
@wraps(msg_callback)
|
||||||
|
def wrapper(msg: ReceiveMessage) -> None:
|
||||||
|
"""Track attributes for write state requests."""
|
||||||
|
tracked_attrs: dict[str, Any] = {
|
||||||
|
attribute: getattr(entity, attribute, UNDEFINED)
|
||||||
|
for attribute in attributes
|
||||||
|
}
|
||||||
|
msg_callback(msg)
|
||||||
|
if not _attrs_have_changed(tracked_attrs):
|
||||||
|
return
|
||||||
|
|
||||||
|
mqtt_data = get_mqtt_data(entity.hass)
|
||||||
|
mqtt_data.state_write_requests.write_state_request(entity)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
class MqttAttributes(Entity):
|
class MqttAttributes(Entity):
|
||||||
"""Mixin used for platforms that support JSON attributes."""
|
"""Mixin used for platforms that support JSON attributes."""
|
||||||
|
|
||||||
|
@ -379,6 +415,7 @@ class MqttAttributes(Entity):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
|
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
|
||||||
def attributes_message_received(msg: ReceiveMessage) -> None:
|
def attributes_message_received(msg: ReceiveMessage) -> None:
|
||||||
try:
|
try:
|
||||||
payload = attr_tpl(msg.payload)
|
payload = attr_tpl(msg.payload)
|
||||||
|
@ -391,9 +428,6 @@ class MqttAttributes(Entity):
|
||||||
and k not in self._attributes_extra_blocked
|
and k not in self._attributes_extra_blocked
|
||||||
}
|
}
|
||||||
self._attr_extra_state_attributes = filtered_dict
|
self._attr_extra_state_attributes = filtered_dict
|
||||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(
|
|
||||||
self
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("JSON result was not a dictionary")
|
_LOGGER.warning("JSON result was not a dictionary")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -488,6 +522,7 @@ class MqttAvailability(Entity):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
|
@write_state_on_attr_change(self, {"available"})
|
||||||
def availability_message_received(msg: ReceiveMessage) -> None:
|
def availability_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle a new received MQTT availability message."""
|
"""Handle a new received MQTT availability message."""
|
||||||
topic = msg.topic
|
topic = msg.topic
|
||||||
|
@ -500,8 +535,6 @@ class MqttAvailability(Entity):
|
||||||
self._available[topic] = False
|
self._available[topic] = False
|
||||||
self._available_latest = False
|
self._available_latest = False
|
||||||
|
|
||||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
|
||||||
|
|
||||||
self._available = {
|
self._available = {
|
||||||
topic: (self._available[topic] if topic in self._available else False)
|
topic: (self._available[topic] if topic in self._available else False)
|
||||||
for topic in self._avail_topics
|
for topic in self._avail_topics
|
||||||
|
|
|
@ -45,6 +45,7 @@ from .mixins import (
|
||||||
MqttAvailability,
|
MqttAvailability,
|
||||||
MqttEntity,
|
MqttEntity,
|
||||||
async_setup_entry_helper,
|
async_setup_entry_helper,
|
||||||
|
write_state_on_attr_change,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
|
@ -52,7 +53,6 @@ from .models import (
|
||||||
ReceiveMessage,
|
ReceiveMessage,
|
||||||
ReceivePayloadType,
|
ReceivePayloadType,
|
||||||
)
|
)
|
||||||
from .util import get_mqtt_data
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -287,13 +287,13 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@write_state_on_attr_change(self, {"_attr_native_value", "_attr_last_reset"})
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
def message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
_update_state(msg)
|
_update_state(msg)
|
||||||
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
|
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
|
||||||
_update_last_reset(msg)
|
_update_last_reset(msg)
|
||||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
|
||||||
|
|
||||||
topics["state_topic"] = {
|
topics["state_topic"] = {
|
||||||
"topic": self._config[CONF_STATE_TOPIC],
|
"topic": self._config[CONF_STATE_TOPIC],
|
||||||
|
|
|
@ -47,6 +47,7 @@ from .test_common import (
|
||||||
help_test_reloadable,
|
help_test_reloadable,
|
||||||
help_test_setting_attribute_via_mqtt_json_message,
|
help_test_setting_attribute_via_mqtt_json_message,
|
||||||
help_test_setting_attribute_with_template,
|
help_test_setting_attribute_with_template,
|
||||||
|
help_test_skipped_async_ha_write_state,
|
||||||
help_test_unique_id,
|
help_test_unique_id,
|
||||||
help_test_unload_config_entry_with_platform,
|
help_test_unload_config_entry_with_platform,
|
||||||
help_test_update_with_json_attrs_bad_json,
|
help_test_update_with_json_attrs_bad_json,
|
||||||
|
@ -1248,3 +1249,38 @@ async def test_entity_name(
|
||||||
await help_test_entity_name(
|
await help_test_entity_name(
|
||||||
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"hass_config",
|
||||||
|
[
|
||||||
|
help_custom_config(
|
||||||
|
binary_sensor.DOMAIN,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"availability_topic": "availability-topic",
|
||||||
|
"json_attributes_topic": "json-attributes-topic",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("topic", "payload1", "payload2"),
|
||||||
|
[
|
||||||
|
("test-topic", "ON", "OFF"),
|
||||||
|
("availability-topic", "online", "offline"),
|
||||||
|
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_skipped_async_ha_write_state(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
topic: str,
|
||||||
|
payload1: str,
|
||||||
|
payload2: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test a write state command is only called when there is change."""
|
||||||
|
await mqtt_mock_entry()
|
||||||
|
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
|
||||||
|
|
|
@ -1925,3 +1925,28 @@ async def help_test_discovery_setup(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
state = hass.states.get(f"{domain}.{name}")
|
state = hass.states.get(f"{domain}.{name}")
|
||||||
assert state and state.state is not None
|
assert state and state.state is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def help_test_skipped_async_ha_write_state(
|
||||||
|
hass: HomeAssistant, topic: str, payload1: str, payload2: str
|
||||||
|
) -> None:
|
||||||
|
"""Test entity.async_ha_write_state is only called on changes."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state"
|
||||||
|
) as mock_async_ha_write_state:
|
||||||
|
assert len(mock_async_ha_write_state.mock_calls) == 0
|
||||||
|
async_fire_mqtt_message(hass, topic, payload1)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(mock_async_ha_write_state.mock_calls) == 1
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, topic, payload1)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(mock_async_ha_write_state.mock_calls) == 1
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, topic, payload2)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(mock_async_ha_write_state.mock_calls) == 2
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, topic, payload2)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(mock_async_ha_write_state.mock_calls) == 2
|
||||||
|
|
|
@ -60,6 +60,7 @@ from .test_common import (
|
||||||
help_test_setting_attribute_via_mqtt_json_message,
|
help_test_setting_attribute_via_mqtt_json_message,
|
||||||
help_test_setting_attribute_with_template,
|
help_test_setting_attribute_with_template,
|
||||||
help_test_setting_blocked_attribute_via_mqtt_json_message,
|
help_test_setting_blocked_attribute_via_mqtt_json_message,
|
||||||
|
help_test_skipped_async_ha_write_state,
|
||||||
help_test_unique_id,
|
help_test_unique_id,
|
||||||
help_test_unload_config_entry_with_platform,
|
help_test_unload_config_entry_with_platform,
|
||||||
help_test_update_with_json_attrs_bad_json,
|
help_test_update_with_json_attrs_bad_json,
|
||||||
|
@ -1437,3 +1438,45 @@ async def test_entity_name(
|
||||||
await help_test_entity_name(
|
await help_test_entity_name(
|
||||||
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"hass_config",
|
||||||
|
[
|
||||||
|
help_custom_config(
|
||||||
|
sensor.DOMAIN,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"availability_topic": "availability-topic",
|
||||||
|
"json_attributes_topic": "json-attributes-topic",
|
||||||
|
"value_template": "{{ value_json.state }}",
|
||||||
|
"last_reset_value_template": "{{ value_json.last_reset }}",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("topic", "payload1", "payload2"),
|
||||||
|
[
|
||||||
|
("test-topic", '{"state":"val1"}', '{"state":"val2"}'),
|
||||||
|
(
|
||||||
|
"test-topic",
|
||||||
|
'{"last_reset":"2023-09-15 15:11:03"}',
|
||||||
|
'{"last_reset":"2023-09-16 15:11:02"}',
|
||||||
|
),
|
||||||
|
("availability-topic", "online", "offline"),
|
||||||
|
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_skipped_async_ha_write_state(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
topic: str,
|
||||||
|
payload1: str,
|
||||||
|
payload2: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test a write state command is only called when there is change."""
|
||||||
|
await mqtt_mock_entry()
|
||||||
|
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue