Rework mqtt callbacks for camera, image and event (#118109)

This commit is contained in:
Jan Bouwhuis 2024-05-25 23:23:45 +02:00 committed by GitHub
parent ae0c00218a
commit f21c0679b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 148 additions and 143 deletions

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from base64 import b64decode from base64 import b64decode
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_QOS, CONF_TOPIC from .const import CONF_QOS, CONF_TOPIC
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@ -97,12 +97,8 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_received(self, msg: ReceiveMessage) -> None:
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload) self._last_image = b64decode(msg.payload)
@ -111,13 +107,21 @@ class MqttCamera(MqttEntity, Camera):
assert isinstance(msg.payload, bytes) assert isinstance(msg.payload, bytes)
self._last_image = msg.payload self._last_image = msg.payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config[CONF_TOPIC], "topic": self._config[CONF_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._image_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": None, "encoding": None,
} }

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@ -31,7 +32,6 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@ -113,13 +113,8 @@ class MqttEvent(MqttEntity, EventEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
@callback @callback
@log_messages(self.hass, self.entity_id) def _event_received(self, msg: ReceiveMessage) -> None:
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
if msg.retain: if msg.retain:
_LOGGER.debug( _LOGGER.debug(
@ -161,10 +156,7 @@ class MqttEvent(MqttEntity, EventEntity):
) )
except KeyError: except KeyError:
_LOGGER.warning( _LOGGER.warning(
( ("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
"`event_type` missing in JSON event payload, "
" '%s' on topic %s"
),
payload, payload,
msg.topic, msg.topic,
) )
@ -194,9 +186,18 @@ class MqttEvent(MqttEntity, EventEntity):
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._event_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }

View file

@ -5,6 +5,7 @@ from __future__ import annotations
from base64 import b64decode from base64 import b64decode
import binascii import binascii
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_ENCODING, CONF_QOS from .const import CONF_ENCODING, CONF_QOS
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@ -143,31 +143,8 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
"""Add a topic to subscribe to."""
encoding: str | None
encoding = (
None
if CONF_IMAGE_TOPIC in self._config
else self._config[CONF_ENCODING] or None
)
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": encoding,
}
return has_topic
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_data_received(self, msg: ReceiveMessage) -> None:
def image_data_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
try: try:
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
@ -186,11 +163,8 @@ class MqttImage(MqttEntity, ImageEntity):
self._attr_image_last_updated = dt_util.utcnow() self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
def image_from_url_request_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
try: try:
url = cv.url(self._url_template(msg.payload)) url = cv.url(self._url_template(msg.payload))
@ -208,7 +182,31 @@ class MqttImage(MqttEntity, ImageEntity):
self._cached_image = None self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received) def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
"""Add a topic to subscribe to."""
encoding: str | None
encoding = (
None
if CONF_IMAGE_TOPIC in self._config
else self._config[CONF_ENCODING] or None
)
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": partial(self._message_callback, msg_callback, None),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": encoding,
}
return has_topic
add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics

View file

@ -1254,12 +1254,14 @@ class MqttEntity(
def _message_callback( def _message_callback(
self, self,
msg_callback: MessageCallbackType, msg_callback: MessageCallbackType,
attributes: set[str], attributes: set[str] | None,
msg: ReceiveMessage, msg: ReceiveMessage,
) -> None: ) -> None:
"""Process the message callback.""" """Process the message callback."""
if attributes is not None:
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes (attribute, getattr(self, attribute, UNDEFINED))
for attribute in attributes
) )
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
@ -1274,7 +1276,7 @@ class MqttEntity(
_LOGGER.warning(exc) _LOGGER.warning(exc)
return return
if self._attrs_have_changed(attrs_snapshot): if attributes is not None and self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)