diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index a1341350a7a..505305cad3e 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -43,9 +43,9 @@ from .mixins import ( MqttAvailability, MqttEntity, async_setup_entry_helper, + write_state_on_attr_change, ) from .models import MqttValueTemplate, ReceiveMessage -from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -191,6 +191,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_is_on"}) def state_message_received(msg: ReceiveMessage) -> None: """Handle a new received MQTT state message.""" # auto-expire enabled? @@ -257,8 +258,6 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): self.hass, off_delay, off_delay_listener ) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) - self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 795eb30e8e2..a01691f0601 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod import asyncio from collections.abc import Callable, Coroutine -from functools import partial +from functools import partial, wraps import logging from typing import TYPE_CHECKING, Any, Protocol, cast, final @@ -101,6 +101,7 @@ from .discovery import ( set_discovery_hash, ) from .models import ( + MessageCallbackType, MqttValueTemplate, PublishPayloadType, ReceiveMessage, @@ -346,6 +347,41 @@ def init_entity_id_from_config( ) +def write_state_on_attr_change( + entity: Entity, attributes: set[str] +) -> Callable[[MessageCallbackType], MessageCallbackType]: + """Wrap an MQTT message callback to track state attribute changes.""" + + def _attrs_have_changed(tracked_attrs: dict[str, Any]) -> bool: + """Return True if attributes on entity changed or if update is forced.""" + if not (write_state := (getattr(entity, "_attr_force_update", False))): + for attribute, last_value in tracked_attrs.items(): + if getattr(entity, attribute, UNDEFINED) != last_value: + write_state = True + break + + return write_state + + def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType: + @wraps(msg_callback) + def wrapper(msg: ReceiveMessage) -> None: + """Track attributes for write state requests.""" + tracked_attrs: dict[str, Any] = { + attribute: getattr(entity, attribute, UNDEFINED) + for attribute in attributes + } + msg_callback(msg) + if not _attrs_have_changed(tracked_attrs): + return + + mqtt_data = get_mqtt_data(entity.hass) + mqtt_data.state_write_requests.write_state_request(entity) + + return wrapper + + return _decorator + + class MqttAttributes(Entity): """Mixin used for platforms that support JSON attributes.""" @@ -379,6 +415,7 @@ class MqttAttributes(Entity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_extra_state_attributes"}) def attributes_message_received(msg: ReceiveMessage) -> None: try: payload = attr_tpl(msg.payload) @@ -391,9 +428,6 @@ class MqttAttributes(Entity): and k not in self._attributes_extra_blocked } self._attr_extra_state_attributes = filtered_dict - get_mqtt_data(self.hass).state_write_requests.write_state_request( - self - ) else: _LOGGER.warning("JSON result was not a dictionary") except ValueError: @@ -488,6 +522,7 @@ class MqttAvailability(Entity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"available"}) def availability_message_received(msg: ReceiveMessage) -> None: """Handle a new received MQTT availability message.""" topic = msg.topic @@ -500,8 +535,6 @@ class MqttAvailability(Entity): self._available[topic] = False self._available_latest = False - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) - self._available = { topic: (self._available[topic] if topic in self._available else False) for topic in self._avail_topics diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 70c8d505b4f..278e70a9737 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -45,6 +45,7 @@ from .mixins import ( MqttAvailability, MqttEntity, async_setup_entry_helper, + write_state_on_attr_change, ) from .models import ( MqttValueTemplate, @@ -52,7 +53,6 @@ from .models import ( ReceiveMessage, ReceivePayloadType, ) -from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -287,13 +287,13 @@ class MqttSensor(MqttEntity, RestoreSensor): ) @callback + @write_state_on_attr_change(self, {"_attr_native_value", "_attr_last_reset"}) @log_messages(self.hass, self.entity_id) def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" _update_state(msg) if CONF_LAST_RESET_VALUE_TEMPLATE in self._config: _update_last_reset(msg) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) topics["state_topic"] = { "topic": self._config[CONF_STATE_TOPIC], diff --git a/tests/components/mqtt/test_binary_sensor.py b/tests/components/mqtt/test_binary_sensor.py index 91a4833b1fc..ea9c8072290 100644 --- a/tests/components/mqtt/test_binary_sensor.py +++ b/tests/components/mqtt/test_binary_sensor.py @@ -47,6 +47,7 @@ from .test_common import ( help_test_reloadable, help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -1248,3 +1249,38 @@ async def test_entity_name( await help_test_entity_name( hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class ) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + binary_sensor.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("test-topic", "ON", "OFF"), + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2) diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index 9aa88c2d7ba..64bece5369e 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -1925,3 +1925,28 @@ async def help_test_discovery_setup( await hass.async_block_till_done() state = hass.states.get(f"{domain}.{name}") assert state and state.state is not None + + +async def help_test_skipped_async_ha_write_state( + hass: HomeAssistant, topic: str, payload1: str, payload2: str +) -> None: + """Test entity.async_ha_write_state is only called on changes.""" + with patch( + "homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state" + ) as mock_async_ha_write_state: + assert len(mock_async_ha_write_state.mock_calls) == 0 + async_fire_mqtt_message(hass, topic, payload1) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 1 + + async_fire_mqtt_message(hass, topic, payload1) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 1 + + async_fire_mqtt_message(hass, topic, payload2) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 2 + + async_fire_mqtt_message(hass, topic, payload2) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 2 diff --git a/tests/components/mqtt/test_sensor.py b/tests/components/mqtt/test_sensor.py index d9c92b315b3..bc75492a03e 100644 --- a/tests/components/mqtt/test_sensor.py +++ b/tests/components/mqtt/test_sensor.py @@ -60,6 +60,7 @@ from .test_common import ( help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, help_test_setting_blocked_attribute_via_mqtt_json_message, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -1437,3 +1438,45 @@ async def test_entity_name( await help_test_entity_name( hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class ) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + sensor.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + "value_template": "{{ value_json.state }}", + "last_reset_value_template": "{{ value_json.last_reset }}", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("test-topic", '{"state":"val1"}', '{"state":"val2"}'), + ( + "test-topic", + '{"last_reset":"2023-09-15 15:11:03"}', + '{"last_reset":"2023-09-16 15:11:02"}', + ), + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)