Remove deprecated callback support for MQTT subscribe (#88543)
* Remove deprecated callback support and fix tests * Add note with removal instruction
This commit is contained in:
parent
24234a55a5
commit
3f79155df6
2 changed files with 61 additions and 166 deletions
|
@ -3,14 +3,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine, Iterable
|
from collections.abc import Callable, Coroutine, Iterable
|
||||||
from functools import lru_cache, partial, wraps
|
from functools import lru_cache
|
||||||
import inspect
|
import inspect
|
||||||
from itertools import chain, groupby
|
from itertools import chain, groupby
|
||||||
import logging
|
import logging
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import ssl
|
import ssl
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
|
@ -72,7 +72,6 @@ from .models import (
|
||||||
PublishMessage,
|
PublishMessage,
|
||||||
PublishPayloadType,
|
PublishPayloadType,
|
||||||
ReceiveMessage,
|
ReceiveMessage,
|
||||||
ReceivePayloadType,
|
|
||||||
)
|
)
|
||||||
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
|
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
|
||||||
|
|
||||||
|
@ -148,55 +147,11 @@ async def async_publish(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
AsyncDeprecatedMessageCallbackType = Callable[
|
|
||||||
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
|
|
||||||
]
|
|
||||||
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
|
|
||||||
DeprecatedMessageCallbackTypes = (
|
|
||||||
AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
|
||||||
def wrap_msg_callback(
|
|
||||||
msg_callback: DeprecatedMessageCallbackTypes,
|
|
||||||
) -> AsyncMessageCallbackType | MessageCallbackType:
|
|
||||||
"""Wrap an MQTT message callback to support deprecated signature."""
|
|
||||||
# Check for partials to properly determine if coroutine function
|
|
||||||
check_func = msg_callback
|
|
||||||
while isinstance(check_func, partial):
|
|
||||||
check_func = check_func.func # type: ignore[unreachable]
|
|
||||||
|
|
||||||
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
|
|
||||||
if asyncio.iscoroutinefunction(check_func):
|
|
||||||
|
|
||||||
@wraps(msg_callback)
|
|
||||||
async def async_wrapper(msg: ReceiveMessage) -> None:
|
|
||||||
"""Call with deprecated signature."""
|
|
||||||
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
|
||||||
msg.topic, msg.payload, msg.qos
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapper_func = async_wrapper
|
|
||||||
return wrapper_func
|
|
||||||
|
|
||||||
@wraps(msg_callback)
|
|
||||||
def wrapper(msg: ReceiveMessage) -> None:
|
|
||||||
"""Call with deprecated signature."""
|
|
||||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
|
||||||
|
|
||||||
wrapper_func = wrapper
|
|
||||||
|
|
||||||
return wrapper_func
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_subscribe(
|
async def async_subscribe(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
topic: str,
|
topic: str,
|
||||||
msg_callback: AsyncMessageCallbackType
|
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
||||||
| MessageCallbackType
|
|
||||||
| DeprecatedMessageCallbackTypes,
|
|
||||||
qos: int = DEFAULT_QOS,
|
qos: int = DEFAULT_QOS,
|
||||||
encoding: str | None = DEFAULT_ENCODING,
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
|
@ -209,8 +164,8 @@ async def async_subscribe(
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
|
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
|
||||||
)
|
)
|
||||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||||
# Count callback parameters which don't have a default value
|
# The signature validation code can be removed from HA core 2023.5.0
|
||||||
non_default = 0
|
non_default = 0
|
||||||
if msg_callback:
|
if msg_callback:
|
||||||
non_default = sum(
|
non_default = sum(
|
||||||
|
@ -218,26 +173,20 @@ async def async_subscribe(
|
||||||
for _, p in inspect.signature(msg_callback).parameters.items()
|
for _, p in inspect.signature(msg_callback).parameters.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
wrapped_msg_callback = msg_callback
|
# Check for not supported callback signatures
|
||||||
# If we have 3 parameters with no default value, wrap the callback
|
# Can be removed from HA core 2023.5.0
|
||||||
if non_default == 3:
|
if non_default != 1:
|
||||||
module = inspect.getmodule(msg_callback)
|
module = inspect.getmodule(msg_callback)
|
||||||
_LOGGER.warning(
|
raise HomeAssistantError(
|
||||||
(
|
"Signature for MQTT msg_callback '{}.{}' is not supported".format(
|
||||||
"Signature of MQTT msg_callback '%s.%s' is deprecated, "
|
module.__name__ if module else "<unknown>", msg_callback.__name__
|
||||||
"this will stop working with HA core 2023.2"
|
)
|
||||||
),
|
|
||||||
module.__name__ if module else "<unknown>",
|
|
||||||
msg_callback.__name__,
|
|
||||||
)
|
|
||||||
wrapped_msg_callback = wrap_msg_callback(
|
|
||||||
cast(DeprecatedMessageCallbackTypes, msg_callback)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async_remove = await mqtt_data.client.async_subscribe(
|
async_remove = await mqtt_data.client.async_subscribe(
|
||||||
topic,
|
topic,
|
||||||
catch_log_exception(
|
catch_log_exception(
|
||||||
wrapped_msg_callback,
|
msg_callback,
|
||||||
lambda msg: (
|
lambda msg: (
|
||||||
f"Exception in {msg_callback.__name__} when handling msg on "
|
f"Exception in {msg_callback.__name__} when handling msg on "
|
||||||
f"'{msg.topic}': '{msg.payload}'"
|
f"'{msg.topic}': '{msg.payload}'"
|
||||||
|
|
|
@ -920,109 +920,42 @@ async def test_subscribe_bad_topic(
|
||||||
await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type]
|
await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||||
async def test_subscribe_deprecated(
|
# Test can be removed from HA core 2023.5.0
|
||||||
|
async def test_subscribe_with_deprecated_callback_fails(
|
||||||
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the subscription of a topic using deprecated callback signature."""
|
"""Test the subscription of a topic using deprecated callback signature fails."""
|
||||||
calls: list[tuple[str, ReceivePayloadType, int]]
|
|
||||||
|
|
||||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
|
||||||
|
|
||||||
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
||||||
"""Record calls."""
|
"""Record calls."""
|
||||||
calls.append((topic, payload, qos))
|
|
||||||
|
|
||||||
calls = []
|
|
||||||
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
assert calls[0][0] == "test-topic"
|
|
||||||
assert calls[0][1] == "test-payload"
|
|
||||||
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
mqtt_mock.async_publish.reset_mock()
|
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError):
|
||||||
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
# Test with partial wrapper
|
# Test with partial wrapper
|
||||||
calls = []
|
with pytest.raises(HomeAssistantError):
|
||||||
unsub = await mqtt.async_subscribe(
|
await mqtt.async_subscribe(hass, "test-topic", RecordCallsPartial(record_calls))
|
||||||
hass, "test-topic", RecordCallsPartial(record_calls)
|
|
||||||
)
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
assert calls[0][0] == "test-topic"
|
|
||||||
assert calls[0][1] == "test-payload"
|
|
||||||
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
# Support for a deprecated callback type will be removed from HA core 2023.2.0
|
# Support for a deprecated callback type was removed with HA core 2023.3.0
|
||||||
async def test_subscribe_deprecated_async(
|
# Test can be removed from HA core 2023.5.0
|
||||||
|
async def test_subscribe_deprecated_async_fails(
|
||||||
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the subscription of a topic using deprecated coroutine signature."""
|
"""Test the subscription of a topic using deprecated coroutine signature fails."""
|
||||||
calls: list[tuple[str, ReceivePayloadType, int]]
|
|
||||||
|
|
||||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
|
||||||
"""Record calls."""
|
"""Record calls."""
|
||||||
calls.append((topic, payload, qos))
|
|
||||||
|
|
||||||
calls = []
|
with pytest.raises(HomeAssistantError):
|
||||||
unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
|
await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
assert calls[0][0] == "test-topic"
|
|
||||||
assert calls[0][1] == "test-payload"
|
|
||||||
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
mqtt_mock.async_publish.reset_mock()
|
|
||||||
|
|
||||||
# Test with partial wrapper
|
# Test with partial wrapper
|
||||||
calls = []
|
with pytest.raises(HomeAssistantError):
|
||||||
unsub = await mqtt.async_subscribe(
|
await mqtt.async_subscribe(
|
||||||
hass, "test-topic", RecordCallsPartial(async_record_calls)
|
hass, "test-topic", RecordCallsPartial(async_record_calls)
|
||||||
)
|
)
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
assert calls[0][0] == "test-topic"
|
|
||||||
assert calls[0][1] == "test-payload"
|
|
||||||
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test-topic", "test-payload")
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert len(calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
async def test_subscribe_topic_not_match(
|
async def test_subscribe_topic_not_match(
|
||||||
|
@ -1300,25 +1233,32 @@ async def test_subscribe_same_topic(
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock().connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
calls_a = MagicMock()
|
calls_a: list[ReceiveMessage] = []
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_a)
|
calls_b: list[ReceiveMessage] = []
|
||||||
|
|
||||||
|
def _callback_a(msg: ReceiveMessage) -> None:
|
||||||
|
calls_a.append(msg)
|
||||||
|
|
||||||
|
def _callback_b(msg: ReceiveMessage) -> None:
|
||||||
|
calls_b.append(msg)
|
||||||
|
|
||||||
|
await mqtt.async_subscribe(hass, "test/state", _callback_a)
|
||||||
async_fire_mqtt_message(
|
async_fire_mqtt_message(
|
||||||
hass, "test/state", "online"
|
hass, "test/state", "online"
|
||||||
) # Simulate a (retained) message
|
) # Simulate a (retained) message
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert calls_a.called
|
assert len(calls_a) == 1
|
||||||
mqtt_client_mock.subscribe.assert_called()
|
mqtt_client_mock.subscribe.assert_called()
|
||||||
calls_a.reset_mock()
|
calls_a = []
|
||||||
mqtt_client_mock.reset_mock()
|
mqtt_client_mock.reset_mock()
|
||||||
|
|
||||||
calls_b = MagicMock()
|
await mqtt.async_subscribe(hass, "test/state", _callback_b)
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
|
||||||
async_fire_mqtt_message(
|
async_fire_mqtt_message(
|
||||||
hass, "test/state", "online"
|
hass, "test/state", "online"
|
||||||
) # Simulate a (retained) message
|
) # Simulate a (retained) message
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert calls_a.called
|
assert len(calls_a) == 1
|
||||||
assert calls_b.called
|
assert len(calls_b) == 1
|
||||||
mqtt_client_mock.subscribe.assert_called()
|
mqtt_client_mock.subscribe.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1353,19 +1293,25 @@ async def test_unsubscribe_race(
|
||||||
# Fake that the client is connected
|
# Fake that the client is connected
|
||||||
mqtt_mock().connected = True
|
mqtt_mock().connected = True
|
||||||
|
|
||||||
calls_a = MagicMock()
|
calls_a: list[ReceiveMessage] = []
|
||||||
calls_b = MagicMock()
|
calls_b: list[ReceiveMessage] = []
|
||||||
|
|
||||||
|
def _callback_a(msg: ReceiveMessage) -> None:
|
||||||
|
calls_a.append(msg)
|
||||||
|
|
||||||
|
def _callback_b(msg: ReceiveMessage) -> None:
|
||||||
|
calls_b.append(msg)
|
||||||
|
|
||||||
mqtt_client_mock.reset_mock()
|
mqtt_client_mock.reset_mock()
|
||||||
unsub = await mqtt.async_subscribe(hass, "test/state", calls_a)
|
unsub = await mqtt.async_subscribe(hass, "test/state", _callback_a)
|
||||||
unsub()
|
unsub()
|
||||||
await mqtt.async_subscribe(hass, "test/state", calls_b)
|
await mqtt.async_subscribe(hass, "test/state", _callback_b)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
async_fire_mqtt_message(hass, "test/state", "online")
|
async_fire_mqtt_message(hass, "test/state", "online")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert not calls_a.called
|
assert not calls_a
|
||||||
assert calls_b.called
|
assert calls_b
|
||||||
|
|
||||||
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
|
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
|
||||||
expected_calls_1 = [
|
expected_calls_1 = [
|
||||||
|
@ -1905,7 +1851,7 @@ async def test_custom_birth_message(
|
||||||
await mqtt_mock_entry_no_yaml_config()
|
await mqtt_mock_entry_no_yaml_config()
|
||||||
birth = asyncio.Event()
|
birth = asyncio.Event()
|
||||||
|
|
||||||
async def wait_birth(topic, payload, qos) -> None:
|
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||||
"""Handle birth message."""
|
"""Handle birth message."""
|
||||||
birth.set()
|
birth.set()
|
||||||
|
|
||||||
|
@ -1940,7 +1886,7 @@ async def test_default_birth_message(
|
||||||
await mqtt_mock_entry_no_yaml_config()
|
await mqtt_mock_entry_no_yaml_config()
|
||||||
birth = asyncio.Event()
|
birth = asyncio.Event()
|
||||||
|
|
||||||
async def wait_birth(topic, payload, qos) -> None:
|
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||||
"""Handle birth message."""
|
"""Handle birth message."""
|
||||||
birth.set()
|
birth.set()
|
||||||
|
|
||||||
|
@ -2015,7 +1961,7 @@ async def test_delayed_birth_message(
|
||||||
mqtt_mock = hass.data["mqtt"].client
|
mqtt_mock = hass.data["mqtt"].client
|
||||||
mqtt_mock.reset_mock()
|
mqtt_mock.reset_mock()
|
||||||
|
|
||||||
async def wait_birth(topic, payload, qos) -> None:
|
async def wait_birth(msg: ReceiveMessage) -> None:
|
||||||
"""Handle birth message."""
|
"""Handle birth message."""
|
||||||
birth.set()
|
birth.set()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue