From 3da4815522b8909d4103a24397a21b188fcea0d1 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Mon, 25 Sep 2023 17:59:33 +0200 Subject: [PATCH] Avoid redundant calls to async_write_ha_state for mqtt fan (#100777) Avoid redundant calls to async_write_ha_state --- homeassistant/components/mqtt/fan.py | 21 +++++++------ tests/components/mqtt/test_fan.py | 46 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index 5c7557c7598..5375fa5afc2 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -50,7 +50,12 @@ from .const import ( PAYLOAD_NONE, ) from .debug_info import log_messages -from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper +from .mixins import ( + MQTT_ENTITY_COMMON_SCHEMA, + MqttEntity, + async_setup_entry_helper, + write_state_on_attr_change, +) from .models import ( MessageCallbackType, MqttCommandTemplate, @@ -59,7 +64,7 @@ from .models import ( ReceiveMessage, ReceivePayloadType, ) -from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic +from .util import valid_publish_topic, valid_subscribe_topic CONF_DIRECTION_STATE_TOPIC = "direction_state_topic" CONF_DIRECTION_COMMAND_TOPIC = "direction_command_topic" @@ -367,6 +372,7 @@ class MqttFan(MqttEntity, FanEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_is_on"}) def state_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message.""" payload = self._value_templates[CONF_STATE](msg.payload) @@ -379,12 +385,12 @@ class MqttFan(MqttEntity, FanEntity): self._attr_is_on = False elif payload == PAYLOAD_NONE: self._attr_is_on = None - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscribe_topic(CONF_STATE_TOPIC, state_received) @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_percentage"}) def percentage_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the percentage.""" rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( @@ -395,7 +401,6 @@ class MqttFan(MqttEntity, FanEntity): return if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]: self._attr_percentage = None - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) return try: percentage = ranged_value_to_percentage( @@ -424,18 +429,17 @@ class MqttFan(MqttEntity, FanEntity): ) return self._attr_percentage = percentage - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received) @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_preset_mode"}) def preset_mode_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for preset mode.""" preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload)) if preset_mode == self._payload["PRESET_MODE_RESET"]: self._attr_preset_mode = None - self.async_write_ha_state() return if not preset_mode: _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic) @@ -450,12 +454,12 @@ class MqttFan(MqttEntity, FanEntity): return self._attr_preset_mode = preset_mode - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received) @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_oscillating"}) def oscillation_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the oscillation.""" payload = self._value_templates[ATTR_OSCILLATING](msg.payload) @@ -466,13 +470,13 @@ class MqttFan(MqttEntity, FanEntity): self._attr_oscillating = True elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]: self._attr_oscillating = False - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received): self._attr_oscillating = False @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_current_direction"}) def direction_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the direction.""" direction = self._value_templates[ATTR_DIRECTION](msg.payload) @@ -480,7 +484,6 @@ class MqttFan(MqttEntity, FanEntity): _LOGGER.debug("Ignoring empty direction from '%s'", msg.topic) return self._attr_current_direction = str(direction) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received) diff --git a/tests/components/mqtt/test_fan.py b/tests/components/mqtt/test_fan.py index 803a0d74766..fe354817aef 100644 --- a/tests/components/mqtt/test_fan.py +++ b/tests/components/mqtt/test_fan.py @@ -60,6 +60,7 @@ from .test_common import ( help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, help_test_setting_blocked_attribute_via_mqtt_json_message, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -2244,3 +2245,48 @@ async def test_unload_entry( await help_test_unload_config_entry_with_platform( hass, mqtt_mock_entry, domain, config ) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + fan.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + "direction_state_topic": "direction-state-topic", + "percentage_state_topic": "percentage-state-topic", + "preset_mode_command_topic": "preset-mode-command-topic", + "preset_mode_state_topic": "preset-mode-state-topic", + "preset_modes": ["eco", "silent"], + "oscillation_state_topic": "oscillation-state-topic", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ("state-topic", "ON", "OFF"), + ("direction-state-topic", "forward", "reverse"), + ("percentage-state-topic", "30", "40"), + ("preset-mode-state-topic", "eco", "silent"), + ("oscillation-state-topic", "oscillate_on", "oscillate_off"), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)