Refactor mqtt callbacks for text (#118130)

This commit is contained in:
Jan Bouwhuis 2024-05-25 23:16:54 +02:00 committed by GitHub
parent e740e2cdc1
commit 6b1b15ef9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import re import re
from typing import Any from typing import Any
@ -34,12 +35,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
self._attr_assumed_state = bool(self._optimistic) self._attr_assumed_state = bool(self._optimistic)
@callback
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = str(self._value_template(msg.payload))
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
def add_subscription( def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None: ) -> None:
if self._config.get(topic) is not None: if self._config.get(topic) is not None:
topics[topic] = { topics[topic] = {
"topic": self._config[topic], "topic": self._config[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@callback add_subscription(
@log_messages(self.hass, self.entity_id) topics,
@write_state_on_attr_change(self, {"_attr_native_value"}) CONF_STATE_TOPIC,
def handle_state_message_received(msg: ReceiveMessage) -> None: self._handle_state_message_received,
"""Handle receiving state message via MQTT.""" {"_attr_native_value"},
payload = str(self._value_template(msg.payload)) )
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics