Avoid redundant calls to async_write_ha_state for mqtt fan (#100777)

Avoid redundant calls to async_write_ha_state
This commit is contained in:
Jan Bouwhuis 2023-09-25 17:59:33 +02:00 committed by GitHub
parent 84451e858e
commit 3da4815522
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 9 deletions

View file

@ -50,7 +50,12 @@ from .const import (
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper
from .mixins import (
MQTT_ENTITY_COMMON_SCHEMA,
MqttEntity,
async_setup_entry_helper,
write_state_on_attr_change,
)
from .models import (
MessageCallbackType,
MqttCommandTemplate,
@ -59,7 +64,7 @@ from .models import (
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
from .util import valid_publish_topic, valid_subscribe_topic
CONF_DIRECTION_STATE_TOPIC = "direction_state_topic"
CONF_DIRECTION_COMMAND_TOPIC = "direction_command_topic"
@ -367,6 +372,7 @@ class MqttFan(MqttEntity, FanEntity):
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
@ -379,12 +385,12 @@ class MqttFan(MqttEntity, FanEntity):
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_percentage"})
def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
@ -395,7 +401,6 @@ class MqttFan(MqttEntity, FanEntity):
return
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
self._attr_percentage = None
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
return
try:
percentage = ranged_value_to_percentage(
@ -424,18 +429,17 @@ class MqttFan(MqttEntity, FanEntity):
)
return
self._attr_percentage = percentage
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_preset_mode"})
def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
self.async_write_ha_state()
return
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
@ -450,12 +454,12 @@ class MqttFan(MqttEntity, FanEntity):
return
self._attr_preset_mode = preset_mode
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_oscillating"})
def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
@ -466,13 +470,13 @@ class MqttFan(MqttEntity, FanEntity):
self._attr_oscillating = True
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
self._attr_oscillating = False
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_direction"})
def direction_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
@ -480,7 +484,6 @@ class MqttFan(MqttEntity, FanEntity):
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
return
self._attr_current_direction = str(direction)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received)

View file

@ -60,6 +60,7 @@ from .test_common import (
help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template,
help_test_setting_blocked_attribute_via_mqtt_json_message,
help_test_skipped_async_ha_write_state,
help_test_unique_id,
help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json,
@ -2244,3 +2245,48 @@ async def test_unload_entry(
await help_test_unload_config_entry_with_platform(
hass, mqtt_mock_entry, domain, config
)
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
fan.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
"direction_state_topic": "direction-state-topic",
"percentage_state_topic": "percentage-state-topic",
"preset_mode_command_topic": "preset-mode-command-topic",
"preset_mode_state_topic": "preset-mode-state-topic",
"preset_modes": ["eco", "silent"],
"oscillation_state_topic": "oscillation-state-topic",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
("state-topic", "ON", "OFF"),
("direction-state-topic", "forward", "reverse"),
("percentage-state-topic", "30", "40"),
("preset-mode-state-topic", "eco", "silent"),
("oscillation-state-topic", "oscillate_on", "oscillate_off"),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)