Create bound callback_message_received method for handling mqtt callbacks (#117951)

* Create bound callback_message_received method for handling mqtt callbacks

* refactor a bit

* fix ruff

* reduce overhead

* cleanup

* cleanup

* Revert changes alarm_control_panel

* Add sensor and binary sensor

* use same pattern for MqttAttributes/MqttAvailability

* remove useless function since we did not need to add to it

* code cleanup

* collapse

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jan Bouwhuis 2024-05-24 11:18:25 +02:00 committed by GitHub
parent d4df86da06
commit 9333965b23
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 210 additions and 152 deletions

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from datetime import datetime, timedelta
from functools import partial
import logging
from typing import Any
@ -37,13 +38,7 @@ from homeassistant.util import dt as dt_util
from . import subscription
from .config import MQTT_RO_SCHEMA
from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE
from .debug_info import log_messages
from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper
from .models import MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@ -162,21 +157,16 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
entity=self,
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
def off_delay_listener(now: datetime) -> None:
def _off_delay_listener(self, now: datetime) -> None:
"""Switch device off after a delay."""
self._delay_listener = None
self._attr_is_on = False
self.async_write_ha_state()
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on", "_expired"})
def state_message_received(msg: ReceiveMessage) -> None:
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message."""
# auto-expire enabled?
if self._expire_after:
# When expire_after is set, and we receive a message, assume device is
@ -238,16 +228,24 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
off_delay: int | None = self._config.get(CONF_OFF_DELAY)
if self._attr_is_on and off_delay is not None:
self._delay_listener = evt.async_call_later(
self.hass, off_delay, off_delay_listener
self.hass, off_delay, self._off_delay_listener
)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config[CONF_STATE_TOPIC],
"msg_callback": state_message_received,
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on", "_expired"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}

View file

@ -86,9 +86,12 @@ def add_subscription(
hass: HomeAssistant,
message_callback: MessageCallbackType,
subscription: str,
entity_id: str | None = None,
) -> None:
"""Prepare debug data for subscription."""
if entity_id := getattr(message_callback, "__entity_id", None):
if not entity_id:
entity_id = getattr(message_callback, "__entity_id", None)
if entity_id:
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
)
@ -104,9 +107,12 @@ def remove_subscription(
hass: HomeAssistant,
message_callback: MessageCallbackType,
subscription: str,
entity_id: str | None = None,
) -> None:
"""Remove debug data for subscription if it exists."""
if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in (
if not entity_id:
entity_id = getattr(message_callback, "__entity_id", None)
if entity_id and entity_id in (
debug_info_entities := hass.data[DATA_MQTT].debug_info_entities
):
debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1

View file

@ -48,6 +48,7 @@ from homeassistant.helpers.event import (
async_track_entity_registry_updated_event,
)
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import (
UNDEFINED,
ConfigType,
@ -93,7 +94,7 @@ from .const import (
MQTT_CONNECTED,
MQTT_DISCONNECTED,
)
from .debug_info import log_message, log_messages
from .debug_info import log_message
from .discovery import (
MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW,
@ -401,6 +402,7 @@ class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes."""
_attributes_extra_blocked: frozenset[str] = frozenset()
_attr_tpl: Callable[[ReceivePayloadType], ReceivePayloadType] | None = None
def __init__(self, config: ConfigType) -> None:
"""Initialize the JSON attributes mixin."""
@ -424,38 +426,21 @@ class MqttAttributes(Entity):
def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
attr_tpl = MqttValueTemplate(
self._attr_tpl = MqttValueTemplate(
self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
def attributes_message_received(msg: ReceiveMessage) -> None:
"""Update extra state attributes."""
payload = attr_tpl(msg.payload)
try:
json_dict = json_loads(payload) if isinstance(payload, str) else None
if isinstance(json_dict, dict):
filtered_dict = {
k: v
for k, v in json_dict.items()
if k not in MQTT_ATTRIBUTES_BLOCKED
and k not in self._attributes_extra_blocked
}
self._attr_extra_state_attributes = filtered_dict
else:
_LOGGER.warning("JSON result was not a dictionary")
except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload)
self._attributes_sub_state = async_prepare_subscribe_topics(
self.hass,
self._attributes_sub_state,
{
CONF_JSON_ATTRS_TOPIC: {
"topic": self._attributes_config.get(CONF_JSON_ATTRS_TOPIC),
"msg_callback": attributes_message_received,
"msg_callback": partial(
self._message_callback, # type: ignore[attr-defined]
self._attributes_message_received,
{"_attr_extra_state_attributes"},
),
"entity_id": self.entity_id,
"qos": self._attributes_config.get(CONF_QOS),
"encoding": self._attributes_config[CONF_ENCODING] or None,
}
@ -472,6 +457,28 @@ class MqttAttributes(Entity):
self.hass, self._attributes_sub_state
)
@callback
def _attributes_message_received(self, msg: ReceiveMessage) -> None:
"""Update extra state attributes."""
if TYPE_CHECKING:
assert self._attr_tpl is not None
payload = self._attr_tpl(msg.payload)
try:
json_dict = json_loads(payload) if isinstance(payload, str) else None
except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload)
else:
if isinstance(json_dict, dict):
filtered_dict = {
k: v
for k, v in json_dict.items()
if k not in MQTT_ATTRIBUTES_BLOCKED
and k not in self._attributes_extra_blocked
}
self._attr_extra_state_attributes = filtered_dict
else:
_LOGGER.warning("JSON result was not a dictionary")
class MqttAvailability(Entity):
"""Mixin used for platforms that report availability."""
@ -535,28 +542,18 @@ class MqttAvailability(Entity):
def _availability_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"available"})
def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload)
if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
self._available[topic] = True
self._available_latest = True
elif payload == self._avail_topics[topic][CONF_PAYLOAD_NOT_AVAILABLE]:
self._available[topic] = False
self._available_latest = False
self._available = {
topic: (self._available.get(topic, False)) for topic in self._avail_topics
}
topics: dict[str, dict[str, Any]] = {
f"availability_{topic}": {
"topic": topic,
"msg_callback": availability_message_received,
"msg_callback": partial(
self._message_callback, # type: ignore[attr-defined]
self._availability_message_received,
{"available"},
),
"entity_id": self.entity_id,
"qos": self._avail_config[CONF_QOS],
"encoding": self._avail_config[CONF_ENCODING] or None,
}
@ -569,6 +566,19 @@ class MqttAvailability(Entity):
topics,
)
@callback
def _availability_message_received(self, msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
avail_topic = self._avail_topics[topic]
payload = avail_topic[CONF_AVAILABILITY_TEMPLATE](msg.payload)
if payload == avail_topic[CONF_PAYLOAD_AVAILABLE]:
self._available[topic] = True
self._available_latest = True
elif payload == avail_topic[CONF_PAYLOAD_NOT_AVAILABLE]:
self._available[topic] = False
self._available_latest = False
async def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state)
@ -1073,6 +1083,7 @@ class MqttEntity(
):
"""Representation of an MQTT entity."""
_attr_force_update = False
_attr_has_entity_name = True
_attr_should_poll = False
_default_name: str | None
@ -1225,6 +1236,45 @@ class MqttEntity(
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
def _attrs_have_changed(
self, attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...]
) -> bool:
"""Return True if attributes on entity changed or if update is forced."""
if self._attr_force_update:
return True
for attribute, last_value in attrs_snapshot:
if getattr(self, attribute, UNDEFINED) != last_value:
return True
return False
@callback
def _message_callback(
self,
msg_callback: MessageCallbackType,
attributes: set[str],
msg: ReceiveMessage,
) -> None:
"""Process the message callback."""
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes
)
mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
msg.subscribed_topic
]["messages"]
if msg not in messages:
messages.append(msg)
try:
msg_callback(msg)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self)
def update_device(
hass: HomeAssistant,

View file

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable
from datetime import datetime, timedelta
from functools import partial
import logging
from typing import Any
@ -40,13 +41,7 @@ from homeassistant.util import dt as dt_util
from . import subscription
from .config import MQTT_RO_SCHEMA
from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE
from .debug_info import log_messages
from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttValueTemplate,
PayloadSentinel,
@ -215,9 +210,9 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
def _update_state(msg: ReceiveMessage) -> None:
# auto-expire enabled?
@ -280,20 +275,22 @@ class MqttSensor(MqttEntity, RestoreSensor):
"Invalid last_reset message '%s' from '%s'", msg.payload, msg.topic
)
@callback
@write_state_on_attr_change(
self, {"_attr_native_value", "_attr_last_reset", "_expired"}
)
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
_update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
_update_last_reset(msg)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received,
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_native_value", "_attr_last_reset", "_expired"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}

View file

@ -26,6 +26,7 @@ class EntitySubscription:
unsubscribe_callback: Callable[[], None] | None = attr.ib()
qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8")
entity_id: str | None = attr.ib(default=None)
def resubscribe_if_necessary(
self, hass: HomeAssistant, other: EntitySubscription | None
@ -41,7 +42,7 @@ class EntitySubscription:
other.unsubscribe_callback()
# Clear debug data if it exists
debug_info.remove_subscription(
self.hass, other.message_callback, str(other.topic)
self.hass, other.message_callback, str(other.topic), other.entity_id
)
if self.topic is None:
@ -49,7 +50,9 @@ class EntitySubscription:
return
# Prepare debug data
debug_info.add_subscription(self.hass, self.message_callback, self.topic)
debug_info.add_subscription(
self.hass, self.message_callback, self.topic, self.entity_id
)
self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding
@ -80,7 +83,7 @@ class EntitySubscription:
def async_prepare_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any],
topics: dict[str, dict[str, Any]],
) -> dict[str, EntitySubscription]:
"""Prepare (re)subscribe to a set of MQTT topics.
@ -106,6 +109,7 @@ def async_prepare_subscribe_topics(
encoding=value.get("encoding", "utf-8"),
hass=hass,
subscribe_task=None,
entity_id=value.get("entity_id", None),
)
# Get the current subscription state
current = current_subscriptions.pop(key, None)
@ -118,7 +122,10 @@ def async_prepare_subscribe_topics(
remaining.unsubscribe_callback()
# Clear debug data if it exists
debug_info.remove_subscription(
hass, remaining.message_callback, str(remaining.topic)
hass,
remaining.message_callback,
str(remaining.topic),
remaining.entity_id,
)
return new_state