Log an exception mqtt client call back throws (#117028)

* Log an exception mqtt client call back throws

* Supress exceptions and add test
This commit is contained in:
Jan Bouwhuis 2024-05-07 21:19:46 +02:00 committed by GitHub
parent db138f3727
commit a3248ccff9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 5 deletions

View file

@ -495,6 +495,9 @@ class MQTT:
mqttc.on_subscribe = self._async_mqtt_on_callback mqttc.on_subscribe = self._async_mqtt_on_callback
mqttc.on_unsubscribe = self._async_mqtt_on_callback mqttc.on_unsubscribe = self._async_mqtt_on_callback
# suppress exceptions at callback
mqttc.suppress_exceptions = True
if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL): if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL):
will_message = PublishMessage(**will) will_message = PublishMessage(**will)
mqttc.will_set( mqttc.will_set(
@ -989,10 +992,21 @@ class MQTT:
def _async_mqtt_on_message( def _async_mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None: ) -> None:
topic = msg.topic try:
# msg.topic is a property that decodes the topic to a string # msg.topic is a property that decodes the topic to a string
# every time it is accessed. Save the result to avoid # every time it is accessed. Save the result to avoid
# decoding the same topic multiple times. # decoding the same topic multiple times.
topic = msg.topic
except UnicodeDecodeError:
bare_topic: bytes = getattr(msg, "_topic")
_LOGGER.warning(
"Skipping received%s message on invalid topic %s (qos=%s): %s",
" retained" if msg.retain else "",
bare_topic,
msg.qos,
msg.payload[0:8192],
)
return
_LOGGER.debug( _LOGGER.debug(
"Received%s message on %s (qos=%s): %s", "Received%s message on %s (qos=%s): %s",
" retained" if msg.retain else "", " retained" if msg.retain else "",

View file

@ -8,8 +8,9 @@ import json
import logging import logging
import socket import socket
import ssl import ssl
import time
from typing import Any, TypedDict from typing import Any, TypedDict
from unittest.mock import ANY, MagicMock, call, mock_open, patch from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import paho.mqtt.client as paho_mqtt import paho.mqtt.client as paho_mqtt
@ -951,6 +952,42 @@ async def test_receiving_non_utf8_message_gets_logged(
) )
async def test_receiving_message_with_non_utf8_topic_gets_logged(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
record_calls: MessageCallbackType,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test receiving a non utf8 encoded topic."""
await mqtt_mock_entry()
await mqtt.async_subscribe(hass, "test-topic", record_calls)
# Local import to avoid processing MQTT modules when running a testcase
# which does not use MQTT.
# pylint: disable-next=import-outside-toplevel
from paho.mqtt.client import MQTTMessage
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.mqtt.models import MqttData
msg = MQTTMessage(topic=b"tasmota/discovery/18FE34E0B760\xcc\x02")
msg.payload = b"Payload"
msg.qos = 2
msg.retain = True
msg.timestamp = time.monotonic()
mqtt_data: MqttData = hass.data["mqtt"]
assert mqtt_data.client
mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)
assert (
"Skipping received retained message on invalid "
"topic b'tasmota/discovery/18FE34E0B760\\xcc\\x02' "
"(qos=2): b'Payload'" in caplog.text
)
async def test_all_subscriptions_run_when_decode_fails( async def test_all_subscriptions_run_when_decode_fails(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,