Remove deprecated callback support for MQTT subscribe (#88543)

* Remove deprecated callback support and fix tests

* Add note with removal instruction
This commit is contained in:
Jan Bouwhuis 2023-02-21 22:21:00 +01:00 committed by GitHub
parent 24234a55a5
commit 3f79155df6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 166 deletions

View file

@ -3,14 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps from functools import lru_cache
import inspect import inspect
from itertools import chain, groupby from itertools import chain, groupby
import logging import logging
from operator import attrgetter from operator import attrgetter
import ssl import ssl
import time import time
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
import uuid import uuid
import async_timeout import async_timeout
@ -72,7 +72,6 @@ from .models import (
PublishMessage, PublishMessage,
PublishPayloadType, PublishPayloadType,
ReceiveMessage, ReceiveMessage,
ReceivePayloadType,
) )
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
@ -148,55 +147,11 @@ async def async_publish(
) )
AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
DeprecatedMessageCallbackTypes = (
AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType
)
# Support for a deprecated callback type will be removed from HA core 2023.2.0
def wrap_msg_callback(
msg_callback: DeprecatedMessageCallbackTypes,
) -> AsyncMessageCallbackType | MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function
check_func = msg_callback
while isinstance(check_func, partial):
check_func = check_func.func # type: ignore[unreachable]
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback)
async def async_wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos
)
wrapper_func = async_wrapper
return wrapper_func
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
wrapper_func = wrapper
return wrapper_func
@bind_hass @bind_hass
async def async_subscribe( async def async_subscribe(
hass: HomeAssistant, hass: HomeAssistant,
topic: str, topic: str,
msg_callback: AsyncMessageCallbackType msg_callback: AsyncMessageCallbackType | MessageCallbackType,
| MessageCallbackType
| DeprecatedMessageCallbackTypes,
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@ -209,8 +164,8 @@ async def async_subscribe(
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled" f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
) )
# Support for a deprecated callback type will be removed from HA core 2023.2.0 # Support for a deprecated callback type was removed with HA core 2023.3.0
# Count callback parameters which don't have a default value # The signature validation code can be removed from HA core 2023.5.0
non_default = 0 non_default = 0
if msg_callback: if msg_callback:
non_default = sum( non_default = sum(
@ -218,26 +173,20 @@ async def async_subscribe(
for _, p in inspect.signature(msg_callback).parameters.items() for _, p in inspect.signature(msg_callback).parameters.items()
) )
wrapped_msg_callback = msg_callback # Check for not supported callback signatures
# If we have 3 parameters with no default value, wrap the callback # Can be removed from HA core 2023.5.0
if non_default == 3: if non_default != 1:
module = inspect.getmodule(msg_callback) module = inspect.getmodule(msg_callback)
_LOGGER.warning( raise HomeAssistantError(
( "Signature for MQTT msg_callback '{}.{}' is not supported".format(
"Signature of MQTT msg_callback '%s.%s' is deprecated, " module.__name__ if module else "<unknown>", msg_callback.__name__
"this will stop working with HA core 2023.2" )
),
module.__name__ if module else "<unknown>",
msg_callback.__name__,
)
wrapped_msg_callback = wrap_msg_callback(
cast(DeprecatedMessageCallbackTypes, msg_callback)
) )
async_remove = await mqtt_data.client.async_subscribe( async_remove = await mqtt_data.client.async_subscribe(
topic, topic,
catch_log_exception( catch_log_exception(
wrapped_msg_callback, msg_callback,
lambda msg: ( lambda msg: (
f"Exception in {msg_callback.__name__} when handling msg on " f"Exception in {msg_callback.__name__} when handling msg on "
f"'{msg.topic}': '{msg.payload}'" f"'{msg.topic}': '{msg.payload}'"

View file

@ -920,109 +920,42 @@ async def test_subscribe_bad_topic(
await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type] await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type]
# Support for a deprecated callback type will be removed from HA core 2023.2.0 # Support for a deprecated callback type was removed with HA core 2023.3.0
async def test_subscribe_deprecated( # Test can be removed from HA core 2023.5.0
async def test_subscribe_with_deprecated_callback_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None: ) -> None:
"""Test the subscription of a topic using deprecated callback signature.""" """Test the subscription of a topic using deprecated callback signature fails."""
calls: list[tuple[str, ReceivePayloadType, int]]
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls.""" """Record calls."""
calls.append((topic, payload, qos))
calls = []
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
mqtt_mock.async_publish.reset_mock()
with pytest.raises(HomeAssistantError):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
# Test with partial wrapper # Test with partial wrapper
calls = [] with pytest.raises(HomeAssistantError):
unsub = await mqtt.async_subscribe( await mqtt.async_subscribe(hass, "test-topic", RecordCallsPartial(record_calls))
hass, "test-topic", RecordCallsPartial(record_calls)
)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
# Support for a deprecated callback type will be removed from HA core 2023.2.0 # Support for a deprecated callback type was removed with HA core 2023.3.0
async def test_subscribe_deprecated_async( # Test can be removed from HA core 2023.5.0
async def test_subscribe_deprecated_async_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None: ) -> None:
"""Test the subscription of a topic using deprecated coroutine signature.""" """Test the subscription of a topic using deprecated coroutine signature fails."""
calls: list[tuple[str, ReceivePayloadType, int]]
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
@callback @callback
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls.""" """Record calls."""
calls.append((topic, payload, qos))
calls = [] with pytest.raises(HomeAssistantError):
unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls) await mqtt.async_subscribe(hass, "test-topic", async_record_calls)
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
mqtt_mock.async_publish.reset_mock()
# Test with partial wrapper # Test with partial wrapper
calls = [] with pytest.raises(HomeAssistantError):
unsub = await mqtt.async_subscribe( await mqtt.async_subscribe(
hass, "test-topic", RecordCallsPartial(async_record_calls) hass, "test-topic", RecordCallsPartial(async_record_calls)
) )
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload"
unsub()
async_fire_mqtt_message(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_subscribe_topic_not_match( async def test_subscribe_topic_not_match(
@ -1300,25 +1233,32 @@ async def test_subscribe_same_topic(
# Fake that the client is connected # Fake that the client is connected
mqtt_mock().connected = True mqtt_mock().connected = True
calls_a = MagicMock() calls_a: list[ReceiveMessage] = []
await mqtt.async_subscribe(hass, "test/state", calls_a) calls_b: list[ReceiveMessage] = []
def _callback_a(msg: ReceiveMessage) -> None:
calls_a.append(msg)
def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg)
await mqtt.async_subscribe(hass, "test/state", _callback_a)
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "test/state", "online" hass, "test/state", "online"
) # Simulate a (retained) message ) # Simulate a (retained) message
await hass.async_block_till_done() await hass.async_block_till_done()
assert calls_a.called assert len(calls_a) == 1
mqtt_client_mock.subscribe.assert_called() mqtt_client_mock.subscribe.assert_called()
calls_a.reset_mock() calls_a = []
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
calls_b = MagicMock() await mqtt.async_subscribe(hass, "test/state", _callback_b)
await mqtt.async_subscribe(hass, "test/state", calls_b)
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "test/state", "online" hass, "test/state", "online"
) # Simulate a (retained) message ) # Simulate a (retained) message
await hass.async_block_till_done() await hass.async_block_till_done()
assert calls_a.called assert len(calls_a) == 1
assert calls_b.called assert len(calls_b) == 1
mqtt_client_mock.subscribe.assert_called() mqtt_client_mock.subscribe.assert_called()
@ -1353,19 +1293,25 @@ async def test_unsubscribe_race(
# Fake that the client is connected # Fake that the client is connected
mqtt_mock().connected = True mqtt_mock().connected = True
calls_a = MagicMock() calls_a: list[ReceiveMessage] = []
calls_b = MagicMock() calls_b: list[ReceiveMessage] = []
def _callback_a(msg: ReceiveMessage) -> None:
calls_a.append(msg)
def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg)
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
unsub = await mqtt.async_subscribe(hass, "test/state", calls_a) unsub = await mqtt.async_subscribe(hass, "test/state", _callback_a)
unsub() unsub()
await mqtt.async_subscribe(hass, "test/state", calls_b) await mqtt.async_subscribe(hass, "test/state", _callback_b)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_mqtt_message(hass, "test/state", "online") async_fire_mqtt_message(hass, "test/state", "online")
await hass.async_block_till_done() await hass.async_block_till_done()
assert not calls_a.called assert not calls_a
assert calls_b.called assert calls_b
# We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe] # We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe]
expected_calls_1 = [ expected_calls_1 = [
@ -1905,7 +1851,7 @@ async def test_custom_birth_message(
await mqtt_mock_entry_no_yaml_config() await mqtt_mock_entry_no_yaml_config()
birth = asyncio.Event() birth = asyncio.Event()
async def wait_birth(topic, payload, qos) -> None: async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message.""" """Handle birth message."""
birth.set() birth.set()
@ -1940,7 +1886,7 @@ async def test_default_birth_message(
await mqtt_mock_entry_no_yaml_config() await mqtt_mock_entry_no_yaml_config()
birth = asyncio.Event() birth = asyncio.Event()
async def wait_birth(topic, payload, qos) -> None: async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message.""" """Handle birth message."""
birth.set() birth.set()
@ -2015,7 +1961,7 @@ async def test_delayed_birth_message(
mqtt_mock = hass.data["mqtt"].client mqtt_mock = hass.data["mqtt"].client
mqtt_mock.reset_mock() mqtt_mock.reset_mock()
async def wait_birth(topic, payload, qos) -> None: async def wait_birth(msg: ReceiveMessage) -> None:
"""Handle birth message.""" """Handle birth message."""
birth.set() birth.set()