Remove deprecated callback support for MQTT subscribe (#88543)

* Remove deprecated callback support and fix tests

* Add note with removal instruction
This commit is contained in:
Jan Bouwhuis 2023-02-21 22:21:00 +01:00 committed by GitHub
parent 24234a55a5
commit 3f79155df6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 166 deletions

View file

@ -3,14 +3,14 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps
from functools import lru_cache
import inspect
from itertools import chain, groupby
import logging
from operator import attrgetter
import ssl
import time
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import uuid
import async_timeout
@ -72,7 +72,6 @@ from .models import (
PublishMessage,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
@ -148,55 +147,11 @@ async def async_publish(
)
AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
DeprecatedMessageCallbackTypes = (
AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType
)
# Support for a deprecated callback type will be removed from HA core 2023.2.0
def wrap_msg_callback(
msg_callback: DeprecatedMessageCallbackTypes,
) -> AsyncMessageCallbackType | MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function
check_func = msg_callback
while isinstance(check_func, partial):
check_func = check_func.func # type: ignore[unreachable]
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback)
async def async_wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos
)
wrapper_func = async_wrapper
return wrapper_func
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
wrapper_func = wrapper
return wrapper_func
@bind_hass
async def async_subscribe(
hass: HomeAssistant,
topic: str,
msg_callback: AsyncMessageCallbackType
| MessageCallbackType
| DeprecatedMessageCallbackTypes,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
@ -209,8 +164,8 @@ async def async_subscribe(
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
# Support for a deprecated callback type will be removed from HA core 2023.2.0
# Count callback parameters which don't have a default value
# Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0
non_default = 0
if msg_callback:
non_default = sum(
@ -218,26 +173,20 @@ async def async_subscribe(
for _, p in inspect.signature(msg_callback).parameters.items()
)
wrapped_msg_callback = msg_callback
# If we have 3 parameters with no default value, wrap the callback
if non_default == 3:
# Check for not supported callback signatures
# Can be removed from HA core 2023.5.0
if non_default != 1:
module = inspect.getmodule(msg_callback)
_LOGGER.warning(
(
"Signature of MQTT msg_callback '%s.%s' is deprecated, "
"this will stop working with HA core 2023.2"
),
module.__name__ if module else "<unknown>",
msg_callback.__name__,
)
wrapped_msg_callback = wrap_msg_callback(
cast(DeprecatedMessageCallbackTypes, msg_callback)
raise HomeAssistantError(
"Signature for MQTT msg_callback '{}.{}' is not supported".format(
module.__name__ if module else "<unknown>", msg_callback.__name__
)
)
async_remove = await mqtt_data.client.async_subscribe(
topic,
catch_log_exception(
wrapped_msg_callback,
msg_callback,
lambda msg: (
f"Exception in {msg_callback.__name__} when handling msg on "
f"'{msg.topic}': '{msg.payload}'"

View file

@ -920,109 +920,42 @@ async def test_subscribe_bad_topic(
await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type]
# Support for a deprecated callback type will be removed from HA core 2023.2.0
async def test_subscribe_deprecated(
# Support for a deprecated callback type was removed with HA core 2023.3.0
# Test can be removed from HA core 2023.5.0
async def test_subscribe_with_deprecated_callback_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None:
"""Test the subscription of a topic using deprecated callback signature."""
calls: list[tuple[str, ReceivePayloadType, int]]
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
"""Test the subscription of a topic using deprecated callback signature fails."""
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls."""
calls.append((topic, payload, qos))
calls = []
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
mqtt_mock.async_publish.reset_mock()
with pytest.raises(HomeAssistantError):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
# Test with partial wrapper
calls = []
unsub = await mqtt.async_subscribe(
hass, "test-topic", RecordCallsPartial(record_calls)
)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
with pytest.raises(HomeAssistantError):
await mqtt.async_subscribe(hass, "test-topic", RecordCallsPartial(record_calls))
# Support for a deprecated callback type will be removed from HA core 2023.2.0
async def test_subscribe_deprecated_async(
# Support for a deprecated callback type was removed with HA core 2023.3.0
# Test can be removed from HA core 2023.5.0
async def test_subscribe_deprecated_async_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None:
"""Test the subscription of a topic using deprecated coroutine signature."""
calls: list[tuple[str, ReceivePayloadType, int]]
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
"""Test the subscription of a topic using deprecated coroutine signature fails."""
@callback
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls."""
calls.append((topic, payload, qos))
calls = []
unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
mqtt_mock.async_publish.reset_mock()
with pytest.raises(HomeAssistantError):
await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
# Test with partial wrapper
calls = []
unsub = await mqtt.async_subscribe(
hass, "test-topic", RecordCallsPartial(async_record_calls)
)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
with pytest.raises(HomeAssistantError):
await mqtt.async_subscribe(
hass, "test-topic", RecordCallsPartial(async_record_calls)
)
async def test_subscribe_topic_not_match(
@ -1300,25 +1233,32 @@ async def test_subscribe_same_topic(
# Fake that the client is connected
mqtt_mock().connected = True
calls_a = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_a)
calls_a: list[ReceiveMessage] = []
calls_b: list[ReceiveMessage] = []
def _callback_a(msg: ReceiveMessage) -> None:
calls_a.append(msg)
def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg)
await mqtt.async_subscribe(hass, "test/state", _callback_a)
async_fire_mqtt_message(
hass, "test/state", "online"
) # Simulate a (retained) message
await hass.async_block_till_done()
assert calls_a.called
assert len(calls_a) == 1
mqtt_client_mock.subscribe.assert_called()
calls_a.reset_mock()
calls_a = []
mqtt_client_mock.reset_mock()
calls_b = MagicMock()
await mqtt.async_subscribe(hass, "test/state", calls_b)
await mqtt.async_subscribe(hass, "test/state", _callback_b)
async_fire_mqtt_message(
hass, "test/state", "online"
) # Simulate a (retained) message
await hass.async_block_till_done()
assert calls_a.called
assert calls_b.called
assert len(calls_a) == 1
assert len(calls_b) == 1
mqtt_client_mock.subscribe.assert_called()
@ -1353,19 +1293,25 @@ async def test_unsubscribe_race(
# Fake that the client is connected
mqtt_mock().connected = True
calls_a = MagicMock()
calls_b = MagicMock()
calls_a: list[ReceiveMessage] = []
calls_b: list[ReceiveMessage] = []
def _callback_a(msg: ReceiveMessage) -> None:
calls_a.append(msg)
def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg)
mqtt_client_mock.reset_mock()
unsub = await mqtt.async_subscribe(hass, "test/state", calls_a)
unsub = await mqtt.async_subscribe(hass, "test/state", _callback_a)
unsub()
await mqtt.async_subscribe(hass, "test/state", calls_b)
await mqtt.async_subscribe(hass, "test/state", _callback_b)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test/state", "online")
await hass.async_block_till_done()
assert not calls_a.called
assert calls_b.called
assert not calls_a
assert calls_b
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
expected_calls_1 = [
@ -1905,7 +1851,7 @@ async def test_custom_birth_message(
await mqtt_mock_entry_no_yaml_config()
birth = asyncio.Event()
async def wait_birth(topic, payload, qos) -> None:
async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message."""
birth.set()
@ -1940,7 +1886,7 @@ async def test_default_birth_message(
await mqtt_mock_entry_no_yaml_config()
birth = asyncio.Event()
async def wait_birth(topic, payload, qos) -> None:
async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message."""
birth.set()
@ -2015,7 +1961,7 @@ async def test_delayed_birth_message(
mqtt_mock = hass.data["mqtt"].client
mqtt_mock.reset_mock()
async def wait_birth(topic, payload, qos) -> None:
async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message."""
birth.set()