diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index dfc88844bd6..ad89a35ec0a 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -3,14 +3,14 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Coroutine, Iterable -from functools import lru_cache, partial, wraps +from functools import lru_cache import inspect from itertools import chain, groupby import logging from operator import attrgetter import ssl import time -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import uuid import async_timeout @@ -72,7 +72,6 @@ from .models import ( PublishMessage, PublishPayloadType, ReceiveMessage, - ReceivePayloadType, ) from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled @@ -148,55 +147,11 @@ async def async_publish( ) -AsyncDeprecatedMessageCallbackType = Callable[ - [str, ReceivePayloadType, int], Coroutine[Any, Any, None] -] -DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] -DeprecatedMessageCallbackTypes = ( - AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType -) - - -# Support for a deprecated callback type will be removed from HA core 2023.2.0 -def wrap_msg_callback( - msg_callback: DeprecatedMessageCallbackTypes, -) -> AsyncMessageCallbackType | MessageCallbackType: - """Wrap an MQTT message callback to support deprecated signature.""" - # Check for partials to properly determine if coroutine function - check_func = msg_callback - while isinstance(check_func, partial): - check_func = check_func.func # type: ignore[unreachable] - - wrapper_func: AsyncMessageCallbackType | MessageCallbackType - if asyncio.iscoroutinefunction(check_func): - - @wraps(msg_callback) - async def async_wrapper(msg: ReceiveMessage) -> None: - """Call with deprecated signature.""" - await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( - msg.topic, msg.payload, msg.qos - ) - - wrapper_func = async_wrapper - return wrapper_func - - @wraps(msg_callback) - def wrapper(msg: ReceiveMessage) -> None: - """Call with deprecated signature.""" - msg_callback(msg.topic, msg.payload, msg.qos) - - wrapper_func = wrapper - - return wrapper_func - - @bind_hass async def async_subscribe( hass: HomeAssistant, topic: str, - msg_callback: AsyncMessageCallbackType - | MessageCallbackType - | DeprecatedMessageCallbackTypes, + msg_callback: AsyncMessageCallbackType | MessageCallbackType, qos: int = DEFAULT_QOS, encoding: str | None = DEFAULT_ENCODING, ) -> CALLBACK_TYPE: @@ -209,8 +164,8 @@ async def async_subscribe( raise HomeAssistantError( f"Cannot subscribe to topic '{topic}', MQTT is not enabled" ) - # Support for a deprecated callback type will be removed from HA core 2023.2.0 - # Count callback parameters which don't have a default value + # Support for a deprecated callback type was removed with HA core 2023.3.0 + # The signature validation code can be removed from HA core 2023.5.0 non_default = 0 if msg_callback: non_default = sum( @@ -218,26 +173,20 @@ async def async_subscribe( for _, p in inspect.signature(msg_callback).parameters.items() ) - wrapped_msg_callback = msg_callback - # If we have 3 parameters with no default value, wrap the callback - if non_default == 3: + # Check for not supported callback signatures + # Can be removed from HA core 2023.5.0 + if non_default != 1: module = inspect.getmodule(msg_callback) - _LOGGER.warning( - ( - "Signature of MQTT msg_callback '%s.%s' is deprecated, " - "this will stop working with HA core 2023.2" - ), - module.__name__ if module else "", - msg_callback.__name__, - ) - wrapped_msg_callback = wrap_msg_callback( - cast(DeprecatedMessageCallbackTypes, msg_callback) + raise HomeAssistantError( + "Signature for MQTT msg_callback '{}.{}' is not supported".format( + module.__name__ if module else "", msg_callback.__name__ + ) ) async_remove = await mqtt_data.client.async_subscribe( topic, catch_log_exception( - wrapped_msg_callback, + msg_callback, lambda msg: ( f"Exception in {msg_callback.__name__} when handling msg on " f"'{msg.topic}': '{msg.payload}'" diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 8462e46318a..47f8743f502 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -920,109 +920,42 @@ async def test_subscribe_bad_topic( await mqtt.async_subscribe(hass, 55, record_calls) # type: ignore[arg-type] -# Support for a deprecated callback type will be removed from HA core 2023.2.0 -async def test_subscribe_deprecated( +# Support for a deprecated callback type was removed with HA core 2023.3.0 +# Test can be removed from HA core 2023.5.0 +async def test_subscribe_with_deprecated_callback_fails( hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator ) -> None: - """Test the subscription of a topic using deprecated callback signature.""" - calls: list[tuple[str, ReceivePayloadType, int]] - - mqtt_mock = await mqtt_mock_entry_no_yaml_config() + """Test the subscription of a topic using deprecated callback signature fails.""" async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: """Record calls.""" - calls.append((topic, payload, qos)) - - calls = [] - unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" - - unsub() - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - mqtt_mock.async_publish.reset_mock() + with pytest.raises(HomeAssistantError): + await mqtt.async_subscribe(hass, "test-topic", record_calls) # Test with partial wrapper - calls = [] - unsub = await mqtt.async_subscribe( - hass, "test-topic", RecordCallsPartial(record_calls) - ) - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" - - unsub() - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 + with pytest.raises(HomeAssistantError): + await mqtt.async_subscribe(hass, "test-topic", RecordCallsPartial(record_calls)) -# Support for a deprecated callback type will be removed from HA core 2023.2.0 -async def test_subscribe_deprecated_async( +# Support for a deprecated callback type was removed with HA core 2023.3.0 +# Test can be removed from HA core 2023.5.0 +async def test_subscribe_deprecated_async_fails( hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator ) -> None: - """Test the subscription of a topic using deprecated coroutine signature.""" - calls: list[tuple[str, ReceivePayloadType, int]] - - mqtt_mock = await mqtt_mock_entry_no_yaml_config() + """Test the subscription of a topic using deprecated coroutine signature fails.""" @callback def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: """Record calls.""" - calls.append((topic, payload, qos)) - calls = [] - unsub = await mqtt.async_subscribe(hass, "test-topic", async_record_calls) - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" - - unsub() - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - mqtt_mock.async_publish.reset_mock() + with pytest.raises(HomeAssistantError): + await mqtt.async_subscribe(hass, "test-topic", async_record_calls) # Test with partial wrapper - calls = [] - unsub = await mqtt.async_subscribe( - hass, "test-topic", RecordCallsPartial(async_record_calls) - ) - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" - - unsub() - - async_fire_mqtt_message(hass, "test-topic", "test-payload") - - await hass.async_block_till_done() - assert len(calls) == 1 + with pytest.raises(HomeAssistantError): + await mqtt.async_subscribe( + hass, "test-topic", RecordCallsPartial(async_record_calls) + ) async def test_subscribe_topic_not_match( @@ -1300,25 +1233,32 @@ async def test_subscribe_same_topic( # Fake that the client is connected mqtt_mock().connected = True - calls_a = MagicMock() - await mqtt.async_subscribe(hass, "test/state", calls_a) + calls_a: list[ReceiveMessage] = [] + calls_b: list[ReceiveMessage] = [] + + def _callback_a(msg: ReceiveMessage) -> None: + calls_a.append(msg) + + def _callback_b(msg: ReceiveMessage) -> None: + calls_b.append(msg) + + await mqtt.async_subscribe(hass, "test/state", _callback_a) async_fire_mqtt_message( hass, "test/state", "online" ) # Simulate a (retained) message await hass.async_block_till_done() - assert calls_a.called + assert len(calls_a) == 1 mqtt_client_mock.subscribe.assert_called() - calls_a.reset_mock() + calls_a = [] mqtt_client_mock.reset_mock() - calls_b = MagicMock() - await mqtt.async_subscribe(hass, "test/state", calls_b) + await mqtt.async_subscribe(hass, "test/state", _callback_b) async_fire_mqtt_message( hass, "test/state", "online" ) # Simulate a (retained) message await hass.async_block_till_done() - assert calls_a.called - assert calls_b.called + assert len(calls_a) == 1 + assert len(calls_b) == 1 mqtt_client_mock.subscribe.assert_called() @@ -1353,19 +1293,25 @@ async def test_unsubscribe_race( # Fake that the client is connected mqtt_mock().connected = True - calls_a = MagicMock() - calls_b = MagicMock() + calls_a: list[ReceiveMessage] = [] + calls_b: list[ReceiveMessage] = [] + + def _callback_a(msg: ReceiveMessage) -> None: + calls_a.append(msg) + + def _callback_b(msg: ReceiveMessage) -> None: + calls_b.append(msg) mqtt_client_mock.reset_mock() - unsub = await mqtt.async_subscribe(hass, "test/state", calls_a) + unsub = await mqtt.async_subscribe(hass, "test/state", _callback_a) unsub() - await mqtt.async_subscribe(hass, "test/state", calls_b) + await mqtt.async_subscribe(hass, "test/state", _callback_b) await hass.async_block_till_done() async_fire_mqtt_message(hass, "test/state", "online") await hass.async_block_till_done() - assert not calls_a.called - assert calls_b.called + assert not calls_a + assert calls_b # We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe] expected_calls_1 = [ @@ -1905,7 +1851,7 @@ async def test_custom_birth_message( await mqtt_mock_entry_no_yaml_config() birth = asyncio.Event() - async def wait_birth(topic, payload, qos) -> None: + async def wait_birth(msg: ReceiveMessage) -> None: """Handle birth message.""" birth.set() @@ -1940,7 +1886,7 @@ async def test_default_birth_message( await mqtt_mock_entry_no_yaml_config() birth = asyncio.Event() - async def wait_birth(topic, payload, qos) -> None: + async def wait_birth(msg: ReceiveMessage) -> None: """Handle birth message.""" birth.set() @@ -2015,7 +1961,7 @@ async def test_delayed_birth_message( mqtt_mock = hass.data["mqtt"].client mqtt_mock.reset_mock() - async def wait_birth(topic, payload, qos) -> None: + async def wait_birth(msg: ReceiveMessage) -> None: """Handle birth message.""" birth.set()