Remove deprecated callback support for MQTT subscribe (#88543)
* Remove deprecated callback support and fix tests * Add note with removal instruction
This commit is contained in:
parent
24234a55a5
commit
3f79155df6
2 changed files with 61 additions and 166 deletions
|
@ -3,14 +3,14 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from functools import lru_cache, partial, wraps
|
||||
from functools import lru_cache
|
||||
import inspect
|
||||
from itertools import chain, groupby
|
||||
import logging
|
||||
from operator import attrgetter
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
import async_timeout
|
||||
|
@ -72,7 +72,6 @@ from .models import (
|
|||
PublishMessage,
|
||||
PublishPayloadType,
|
||||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
|
||||
|
||||
|
@ -148,55 +147,11 @@ async def async_publish(
|
|||
)
|
||||
|
||||
|
||||
AsyncDeprecatedMessageCallbackType = Callable[
|
||||
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
|
||||
]
|
||||
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
|
||||
DeprecatedMessageCallbackTypes = (
|
||||
AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType
|
||||
)
|
||||
|
||||
|
||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
||||
def wrap_msg_callback(
|
||||
msg_callback: DeprecatedMessageCallbackTypes,
|
||||
) -> AsyncMessageCallbackType | MessageCallbackType:
|
||||
"""Wrap an MQTT message callback to support deprecated signature."""
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = msg_callback
|
||||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func # type: ignore[unreachable]
|
||||
|
||||
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
|
||||
@wraps(msg_callback)
|
||||
async def async_wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
||||
msg.topic, msg.payload, msg.qos
|
||||
)
|
||||
|
||||
wrapper_func = async_wrapper
|
||||
return wrapper_func
|
||||
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
|
||||
wrapper_func = wrapper
|
||||
|
||||
return wrapper_func
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_subscribe(
|
||||
hass: HomeAssistant,
|
||||
topic: str,
|
||||
msg_callback: AsyncMessageCallbackType
|
||||
| MessageCallbackType
|
||||
| DeprecatedMessageCallbackTypes,
|
||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
||||
qos: int = DEFAULT_QOS,
|
||||
encoding: str | None = DEFAULT_ENCODING,
|
||||
) -> CALLBACK_TYPE:
|
||||
|
@ -209,8 +164,8 @@ async def async_subscribe(
|
|||
raise HomeAssistantError(
|
||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
|
||||
)
|
||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
||||
# Count callback parameters which don't have a default value
|
||||
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||
# The signature validation code can be removed from HA core 2023.5.0
|
||||
non_default = 0
|
||||
if msg_callback:
|
||||
non_default = sum(
|
||||
|
@ -218,26 +173,20 @@ async def async_subscribe(
|
|||
for _, p in inspect.signature(msg_callback).parameters.items()
|
||||
)
|
||||
|
||||
wrapped_msg_callback = msg_callback
|
||||
# If we have 3 parameters with no default value, wrap the callback
|
||||
if non_default == 3:
|
||||
# Check for not supported callback signatures
|
||||
# Can be removed from HA core 2023.5.0
|
||||
if non_default != 1:
|
||||
module = inspect.getmodule(msg_callback)
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Signature of MQTT msg_callback '%s.%s' is deprecated, "
|
||||
"this will stop working with HA core 2023.2"
|
||||
),
|
||||
module.__name__ if module else "<unknown>",
|
||||
msg_callback.__name__,
|
||||
)
|
||||
wrapped_msg_callback = wrap_msg_callback(
|
||||
cast(DeprecatedMessageCallbackTypes, msg_callback)
|
||||
raise HomeAssistantError(
|
||||
"Signature for MQTT msg_callback '{}.{}' is not supported".format(
|
||||
module.__name__ if module else "<unknown>", msg_callback.__name__
|
||||
)
|
||||
)
|
||||
|
||||
async_remove = await mqtt_data.client.async_subscribe(
|
||||
topic,
|
||||
catch_log_exception(
|
||||
wrapped_msg_callback,
|
||||
msg_callback,
|
||||
lambda msg: (
|
||||
f"Exception in {msg_callback.__name__} when handling msg on "
|
||||
f"'{msg.topic}': '{msg.payload}'"
|
||||
|
|
|
@ -920,109 +920,42 @@ async def test_subscribe_bad_topic(
|
|||
await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
||||
async def test_subscribe_deprecated(
|
||||
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||
# Test can be removed from HA core 2023.5.0
|
||||
async def test_subscribe_with_deprecated_callback_fails(
|
||||
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
||||
) -> None:
|
||||
"""Test the subscription of a topic using deprecated callback signature."""
|
||||
calls: list[tuple[str, ReceivePayloadType, int]]
|
||||
|
||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||
"""Test the subscription of a topic using deprecated callback signature fails."""
|
||||
|
||||
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
||||
"""Record calls."""
|
||||
calls.append((topic, payload, qos))
|
||||
|
||||
calls = []
|
||||
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0] == "test-topic"
|
||||
assert calls[0][1] == "test-payload"
|
||||
|
||||
unsub()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
mqtt_mock.async_publish.reset_mock()
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
# Test with partial wrapper
|
||||
calls = []
|
||||
unsub = await mqtt.async_subscribe(
|
||||
hass, "test-topic", RecordCallsPartial(record_calls)
|
||||
)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0] == "test-topic"
|
||||
assert calls[0][1] == "test-payload"
|
||||
|
||||
unsub()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await mqtt.async_subscribe(hass, "test-topic", RecordCallsPartial(record_calls))
|
||||
|
||||
|
||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
||||
async def test_subscribe_deprecated_async(
|
||||
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||
# Test can be removed from HA core 2023.5.0
|
||||
async def test_subscribe_deprecated_async_fails(
|
||||
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
||||
) -> None:
|
||||
"""Test the subscription of a topic using deprecated coroutine signature."""
|
||||
calls: list[tuple[str, ReceivePayloadType, int]]
|
||||
|
||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||
"""Test the subscription of a topic using deprecated coroutine signature fails."""
|
||||
|
||||
@callback
|
||||
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
||||
"""Record calls."""
|
||||
calls.append((topic, payload, qos))
|
||||
|
||||
calls = []
|
||||
unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0] == "test-topic"
|
||||
assert calls[0][1] == "test-payload"
|
||||
|
||||
unsub()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
mqtt_mock.async_publish.reset_mock()
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
|
||||
|
||||
# Test with partial wrapper
|
||||
calls = []
|
||||
unsub = await mqtt.async_subscribe(
|
||||
hass, "test-topic", RecordCallsPartial(async_record_calls)
|
||||
)
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0] == "test-topic"
|
||||
assert calls[0][1] == "test-payload"
|
||||
|
||||
unsub()
|
||||
|
||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await mqtt.async_subscribe(
|
||||
hass, "test-topic", RecordCallsPartial(async_record_calls)
|
||||
)
|
||||
|
||||
|
||||
async def test_subscribe_topic_not_match(
|
||||
|
@ -1300,25 +1233,32 @@ async def test_subscribe_same_topic(
|
|||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
|
||||
calls_a = MagicMock()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||
calls_a: list[ReceiveMessage] = []
|
||||
calls_b: list[ReceiveMessage] = []
|
||||
|
||||
def _callback_a(msg: ReceiveMessage) -> None:
|
||||
calls_a.append(msg)
|
||||
|
||||
def _callback_b(msg: ReceiveMessage) -> None:
|
||||
calls_b.append(msg)
|
||||
|
||||
await mqtt.async_subscribe(hass, "test/state", _callback_a)
|
||||
async_fire_mqtt_message(
|
||||
hass, "test/state", "online"
|
||||
) # Simulate a (retained) message
|
||||
await hass.async_block_till_done()
|
||||
assert calls_a.called
|
||||
assert len(calls_a) == 1
|
||||
mqtt_client_mock.subscribe.assert_called()
|
||||
calls_a.reset_mock()
|
||||
calls_a = []
|
||||
mqtt_client_mock.reset_mock()
|
||||
|
||||
calls_b = MagicMock()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
||||
await mqtt.async_subscribe(hass, "test/state", _callback_b)
|
||||
async_fire_mqtt_message(
|
||||
hass, "test/state", "online"
|
||||
) # Simulate a (retained) message
|
||||
await hass.async_block_till_done()
|
||||
assert calls_a.called
|
||||
assert calls_b.called
|
||||
assert len(calls_a) == 1
|
||||
assert len(calls_b) == 1
|
||||
mqtt_client_mock.subscribe.assert_called()
|
||||
|
||||
|
||||
|
@ -1353,19 +1293,25 @@ async def test_unsubscribe_race(
|
|||
# Fake that the client is connected
|
||||
mqtt_mock().connected = True
|
||||
|
||||
calls_a = MagicMock()
|
||||
calls_b = MagicMock()
|
||||
calls_a: list[ReceiveMessage] = []
|
||||
calls_b: list[ReceiveMessage] = []
|
||||
|
||||
def _callback_a(msg: ReceiveMessage) -> None:
|
||||
calls_a.append(msg)
|
||||
|
||||
def _callback_b(msg: ReceiveMessage) -> None:
|
||||
calls_b.append(msg)
|
||||
|
||||
mqtt_client_mock.reset_mock()
|
||||
unsub = await mqtt.async_subscribe(hass, "test/state", calls_a)
|
||||
unsub = await mqtt.async_subscribe(hass, "test/state", _callback_a)
|
||||
unsub()
|
||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
||||
await mqtt.async_subscribe(hass, "test/state", _callback_b)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
async_fire_mqtt_message(hass, "test/state", "online")
|
||||
await hass.async_block_till_done()
|
||||
assert not calls_a.called
|
||||
assert calls_b.called
|
||||
assert not calls_a
|
||||
assert calls_b
|
||||
|
||||
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
|
||||
expected_calls_1 = [
|
||||
|
@ -1905,7 +1851,7 @@ async def test_custom_birth_message(
|
|||
await mqtt_mock_entry_no_yaml_config()
|
||||
birth = asyncio.Event()
|
||||
|
||||
async def wait_birth(topic, payload, qos) -> None:
|
||||
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||
"""Handle birth message."""
|
||||
birth.set()
|
||||
|
||||
|
@ -1940,7 +1886,7 @@ async def test_default_birth_message(
|
|||
await mqtt_mock_entry_no_yaml_config()
|
||||
birth = asyncio.Event()
|
||||
|
||||
async def wait_birth(topic, payload, qos) -> None:
|
||||
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||
"""Handle birth message."""
|
||||
birth.set()
|
||||
|
||||
|
@ -2015,7 +1961,7 @@ async def test_delayed_birth_message(
|
|||
mqtt_mock = hass.data["mqtt"].client
|
||||
mqtt_mock.reset_mock()
|
||||
|
||||
async def wait_birth(topic, payload, qos) -> None:
|
||||
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||
"""Handle birth message."""
|
||||
birth.set()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue