From ec1b8b616f73d5b7f7021cc0fc405be34a5ed115 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Tue, 14 Mar 2023 11:13:55 +0100 Subject: [PATCH] Debounce and group MQTT subscriptions (#88862) * Debounce and group mqtt subscriptions * Cleanup * Do not cooldown on resubscribe * Remove lock from task Co-authored-by: Erik Montnemery * ruff * Longer initial cool down. Manages unsubscribes * Own lock for access to self._pending_subscriptions * adjust * Subscribe to highest QoS when sharing subscription * do not block _pending_subscriptions_lock with io * Test the highest qos is subscribed at * Cleanup max qos * Follow up comments part 1 * Make docstr more generic * Make max qos update thread safe * Add lock on clearing _max_qos when resubscribing * Wait for linger task * User copy * Check for key before cleaning up * Fix lingering task * Do not use a lock * do not await _async_queue_subscriptions * Replace copy with assignment * Update max qos before returning * Do not iterate if max_qos == 0 * Do not ieterate subs if max qos == 0 * Set initial cooldown correctly * Ensure discovery cooldown ends after subscribing * plan last subscribe with debouncer timeout * cooldown if self._pending_subscriptions is set * Revert format changes * Remove stale assingnment self._last_subscribe * Remove not used property * Also check while for pending subscriptions * revert first added sleep() * Optimize --------- Co-authored-by: Erik Montnemery Co-authored-by: J. Nick Koston --- homeassistant/components/mqtt/client.py | 151 +++++++++++++-- tests/components/mqtt/test_discovery.py | 6 + tests/components/mqtt/test_init.py | 242 ++++++++++++++++++++++-- 3 files changed, 361 insertions(+), 38 deletions(-) diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index e717da5144c..5585a6cee5f 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -83,6 +83,8 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) DISCOVERY_COOLDOWN = 2 +INITIAL_SUBSCRIBE_COOLDOWN = 1.0 +SUBSCRIBE_COOLDOWN = 0.1 TIMEOUT_ACK = 10 SubscribePayloadType = str | bytes # Only bytes if encoding is None @@ -295,10 +297,86 @@ def _is_simple_match(topic: str) -> bool: return not ("+" in topic or "#" in topic) +class EnsureJobAfterCooldown: + """Ensure a cool down period before executing a job. + + When a new execute request arrives we cancel the current request + and start a new one. + """ + + def __init__( + self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]] + ) -> None: + """Initialize the timer.""" + self._loop = asyncio.get_running_loop() + self._timeout = timeout + self._callback = callback_job + self._task: asyncio.Future | None = None + self._timer: asyncio.TimerHandle | None = None + + def set_timeout(self, timeout: float) -> None: + """Set a new timeout period.""" + self._timeout = timeout + + async def _async_job(self) -> None: + """Execute after a cooldown period.""" + try: + await self._callback() + except HomeAssistantError as ha_error: + _LOGGER.error("%s", ha_error) + + @callback + def _async_task_done(self, task: asyncio.Future) -> None: + """Handle task done.""" + self._task = None + + @callback + def _async_execute(self) -> None: + """Execute the job.""" + if self._task: + # Task already running, + # so we schedule another run + self.async_schedule() + return + + self._async_cancel_timer() + self._task = asyncio.create_task(self._async_job()) + self._task.add_done_callback(self._async_task_done) + + @callback + def _async_cancel_timer(self) -> None: + """Cancel any pending task.""" + if self._timer: + self._timer.cancel() + self._timer = None + + @callback + def async_schedule(self) -> None: + """Ensure we execute after a cooldown period.""" + # We want to reschedule the timer in the future + # every time this is called. + self._async_cancel_timer() + self._timer = self._loop.call_later(self._timeout, self._async_execute) + + async def async_cleanup(self) -> None: + """Cleanup any pending task.""" + self._async_cancel_timer() + if not self._task: + return + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error cleaning up task", exc_info=True) + + class MQTT: """Home Assistant MQTT client.""" _mqttc: mqtt.Client + _last_subscribe: float def __init__( self, @@ -316,12 +394,16 @@ class MQTT: self._wildcard_subscriptions: list[Subscription] = [] self.connected = False self._ha_started = asyncio.Event() - self._last_subscribe = time.time() self._cleanup_on_unload: list[Callable[[], None]] = [] self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client self._pending_operations: dict[int, asyncio.Event] = {} self._pending_operations_condition = asyncio.Condition() + self._subscribe_debouncer = EnsureJobAfterCooldown( + INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions + ) + self._max_qos: dict[str, int] = {} # topic, max qos + self._pending_subscriptions: dict[str, int] = {} # topic, qos if self.hass.state == CoreState.running: self._ha_started.set() @@ -442,6 +524,11 @@ class MQTT: """Return False if there are unprocessed ACKs.""" return not any(not op.is_set() for op in self._pending_operations.values()) + # stop waiting for any pending subscriptions + await self._subscribe_debouncer.async_cleanup() + # reset timeout to initial subscribe cooldown + self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN) + # wait for ACKs to be processed async with self._pending_operations_condition: await self._pending_operations_condition.wait_for(no_more_acks) @@ -494,6 +581,20 @@ class MQTT: except (KeyError, ValueError) as ex: raise HomeAssistantError("Can't remove subscription twice") from ex + @callback + def _async_queue_subscriptions( + self, subscriptions: Iterable[tuple[str, int]], queue_only: bool = False + ) -> None: + """Queue requested subscriptions.""" + for subscription in subscriptions: + topic, qos = subscription + max_qos = max(qos, self._max_qos.setdefault(topic, qos)) + self._max_qos[topic] = max_qos + self._pending_subscriptions[topic] = max_qos + if queue_only: + return + self._subscribe_debouncer.async_schedule() + async def async_subscribe( self, topic: str, @@ -516,15 +617,13 @@ class MQTT: # Only subscribe if currently connected. if self.connected: - self._last_subscribe = time.time() - await self._async_perform_subscriptions(((topic, qos),)) + self._async_queue_subscriptions(((topic, qos),)) @callback def async_remove() -> None: """Remove subscription.""" self._async_untrack_subscription(subscription) self._matching_subscriptions.cache_clear() - # Only unsubscribe if currently connected if self.connected: self.hass.async_create_task(self._async_unsubscribe(topic)) @@ -543,21 +642,27 @@ class MQTT: _raise_on_error(result) return mid - async with self._paho_lock: - if self._is_active_subscription(topic): - # Other subscriptions on topic remaining - don't unsubscribe. + if self._is_active_subscription(topic): + if self._max_qos[topic] == 0: return - + subs = self._matching_subscriptions(topic) + self._max_qos[topic] = max(sub.qos for sub in subs) + # Other subscriptions on topic remaining - don't unsubscribe. + return + if topic in self._max_qos: + del self._max_qos[topic] + if topic in self._pending_subscriptions: + # avoid any pending subscription to be executed + del self._pending_subscriptions[topic] + async with self._paho_lock: mid = await self.hass.async_add_executor_job(_client_unsubscribe, topic) await self._register_mid(mid) self.hass.async_create_task(self._wait_for_mid(mid)) - async def _async_perform_subscriptions( - self, subscriptions: Iterable[tuple[str, int]] - ) -> None: + async def _async_perform_subscriptions(self) -> None: """Perform MQTT client subscriptions.""" - + subscriptions: dict[str, int] # Section 3.3.1.3 in the specification: # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html # When sending a PUBLISH Packet to a Client the Server MUST @@ -573,16 +678,20 @@ class MQTT: def _process_client_subscriptions() -> list[tuple[int, int]]: """Initiate all subscriptions on the MQTT client and return the results.""" subscribe_result_list = [] - for topic, qos in subscriptions: + for topic, qos in subscriptions.items(): result, mid = self._mqttc.subscribe(topic, qos) subscribe_result_list.append((result, mid)) _LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos) return subscribe_result_list + subscriptions = self._pending_subscriptions + self._pending_subscriptions = {} + async with self._paho_lock: results = await self.hass.async_add_executor_job( _process_client_subscriptions ) + self._last_subscribe = time.time() tasks: list[Coroutine[Any, Any, None]] = [] errors: list[int] = [] @@ -639,6 +748,8 @@ class MQTT: async def publish_birth_message(birth_message: PublishMessage) -> None: await self._ha_started.wait() # Wait for Home Assistant to start await self._discovery_cooldown() # Wait for MQTT discovery to cool down + # Update subscribe cooldown period to a shorter time + self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN) await self.async_publish( topic=birth_message.topic, payload=birth_message.payload, @@ -654,16 +765,19 @@ class MQTT: async def _async_resubscribe(self) -> None: """Resubscribe on reconnect.""" # Group subscriptions to only re-subscribe once for each topic. + self._max_qos.clear() keyfunc = attrgetter("topic") - await self._async_perform_subscriptions( + self._async_queue_subscriptions( [ # Re-subscribe with the highest requested qos (topic, max(subscription.qos for subscription in subs)) for topic, subs in groupby( sorted(self.subscriptions, key=keyfunc), keyfunc ) - ] + ], + queue_only=True, ) + await self._async_perform_subscriptions() def _mqtt_on_message( self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage @@ -785,13 +899,14 @@ class MQTT: self._pending_operations_condition.notify_all() async def _discovery_cooldown(self) -> None: + """Wait until all discovery and subscriptions are processed.""" now = time.time() # Reset discovery and subscribe cooldowns self._mqtt_data.last_discovery = now self._last_subscribe = now last_discovery = self._mqtt_data.last_discovery - last_subscribe = self._last_subscribe + last_subscribe = now if self._pending_subscriptions else self._last_subscribe wait_until = max( last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN ) @@ -799,7 +914,9 @@ class MQTT: await asyncio.sleep(wait_until - now) now = time.time() last_discovery = self._mqtt_data.last_discovery - last_subscribe = self._last_subscribe + last_subscribe = ( + now if self._pending_subscriptions else self._last_subscribe + ) wait_until = max( last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN ) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 5cd615e0eb6..a21f69544e8 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -1374,6 +1374,8 @@ async def test_complex_discovery_topic_prefix( @patch("homeassistant.components.mqtt.PLATFORMS", []) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) async def test_mqtt_integration_discovery_subscribe_unsubscribe( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1392,6 +1394,7 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( ): await async_start(hass, "homeassistant", entry) await hass.async_block_till_done() + await hass.async_block_till_done() mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) assert not mqtt_client_mock.unsubscribe.called @@ -1418,6 +1421,8 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( @patch("homeassistant.components.mqtt.PLATFORMS", []) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) async def test_mqtt_discovery_unsubscribe_once( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1436,6 +1441,7 @@ async def test_mqtt_discovery_unsubscribe_once( ): await async_start(hass, "homeassistant", entry) await hass.async_block_till_done() + await hass.async_block_till_done() mqtt_client_mock.subscribe.assert_any_call("comp/discovery/#", 0) assert not mqtt_client_mock.unsubscribe.called diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 8c6ee21c932..ec373aab0d7 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -17,6 +17,7 @@ import yaml from homeassistant import config as hass_config from homeassistant.components import mqtt from homeassistant.components.mqtt import CONFIG_SCHEMA, debug_info +from homeassistant.components.mqtt.client import EnsureJobAfterCooldown from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState @@ -1262,6 +1263,9 @@ async def test_subscribe_special_characters( assert calls[0].payload == payload +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) async def test_subscribe_same_topic( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1286,26 +1290,35 @@ async def test_subscribe_same_topic( def _callback_b(msg: ReceiveMessage) -> None: calls_b.append(msg) - await mqtt.async_subscribe(hass, "test/state", _callback_a) + await mqtt.async_subscribe(hass, "test/state", _callback_a, qos=0) async_fire_mqtt_message( hass, "test/state", "online" - ) # Simulate a (retained) message + ) # Simulate a (retained) message replaying + async_fire_time_changed(hass, utcnow() + timedelta(seconds=1)) await hass.async_block_till_done() assert len(calls_a) == 1 mqtt_client_mock.subscribe.assert_called() calls_a = [] mqtt_client_mock.reset_mock() - await mqtt.async_subscribe(hass, "test/state", _callback_b) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) + await hass.async_block_till_done() + await mqtt.async_subscribe(hass, "test/state", _callback_b, qos=1) async_fire_mqtt_message( hass, "test/state", "online" - ) # Simulate a (retained) message + ) # Simulate a (retained) message replaying + async_fire_time_changed(hass, utcnow() + timedelta(seconds=1)) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=1)) await hass.async_block_till_done() assert len(calls_a) == 1 assert len(calls_b) == 1 mqtt_client_mock.subscribe.assert_called() +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) async def test_not_calling_unsubscribe_with_active_subscribers( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1317,8 +1330,10 @@ async def test_not_calling_unsubscribe_with_active_subscribers( # Fake that the client is connected mqtt_mock().connected = True - unsub = await mqtt.async_subscribe(hass, "test/state", record_calls) - await mqtt.async_subscribe(hass, "test/state", record_calls) + unsub = await mqtt.async_subscribe(hass, "test/state", record_calls, 2) + await mqtt.async_subscribe(hass, "test/state", record_calls, 1) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown await hass.async_block_till_done() assert mqtt_client_mock.subscribe.called @@ -1327,6 +1342,30 @@ async def test_not_calling_unsubscribe_with_active_subscribers( assert not mqtt_client_mock.unsubscribe.called +async def test_not_calling_subscribe_when_unsubscribed_within_cooldown( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, + record_calls: MessageCallbackType, +) -> None: + """Test not calling subscribe() when it is unsubscribed. + + Make sure subscriptions are cleared if unsubscribed before + the subscribe cool down period has ended. + """ + mqtt_mock = await mqtt_mock_entry_no_yaml_config() + # Fake that the client is connected + mqtt_mock().connected = True + + unsub = await mqtt.async_subscribe(hass, "test/state", record_calls) + unsub() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown + await hass.async_block_till_done() + assert not mqtt_client_mock.subscribe.called + + +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) async def test_unsubscribe_race( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1351,13 +1390,15 @@ async def test_unsubscribe_race( unsub() await mqtt.async_subscribe(hass, "test/state", _callback_b) await hass.async_block_till_done() + await hass.async_block_till_done() async_fire_mqtt_message(hass, "test/state", "online") await hass.async_block_till_done() assert not calls_a assert calls_b - # We allow either calls [subscribe, unsubscribe, subscribe] or [subscribe, subscribe] + # We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or + # when both subscriptions were combined [subscribe] expected_calls_1 = [ call.subscribe("test/state", 0), call.unsubscribe("test/state"), @@ -1367,13 +1408,23 @@ async def test_unsubscribe_race( call.subscribe("test/state", 0), call.subscribe("test/state", 0), ] - assert mqtt_client_mock.mock_calls in (expected_calls_1, expected_calls_2) + expected_calls_3 = [ + call.subscribe("test/state", 0), + ] + assert mqtt_client_mock.mock_calls in ( + expected_calls_1, + expected_calls_2, + expected_calls_3, + ) @pytest.mark.parametrize( "mqtt_config_entry_data", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], ) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) async def test_restore_subscriptions_on_reconnect( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1386,13 +1437,15 @@ async def test_restore_subscriptions_on_reconnect( mqtt_mock().connected = True await mqtt.async_subscribe(hass, "test/state", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown await hass.async_block_till_done() assert mqtt_client_mock.subscribe.call_count == 1 mqtt_client_mock.on_disconnect(None, None, 0) - with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): - mqtt_client_mock.on_connect(None, None, None, 0) - await hass.async_block_till_done() + mqtt_client_mock.on_connect(None, None, None, 0) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown + await hass.async_block_till_done() + await hass.async_block_till_done() assert mqtt_client_mock.subscribe.call_count == 2 @@ -1400,6 +1453,9 @@ async def test_restore_subscriptions_on_reconnect( "mqtt_config_entry_data", [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], ) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 1.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 1.0) async def test_restore_all_active_subscriptions_on_reconnect( hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, @@ -1412,14 +1468,15 @@ async def test_restore_all_active_subscriptions_on_reconnect( mqtt_mock().connected = True unsub = await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2) - await mqtt.async_subscribe(hass, "test/state", record_calls) await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1) + await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown await hass.async_block_till_done() + # the subscribtion with the highest QoS should survive expected = [ call("test/state", 2), - call("test/state", 0), - call("test/state", 1), ] assert mqtt_client_mock.subscribe.mock_calls == expected @@ -1428,13 +1485,60 @@ async def test_restore_all_active_subscriptions_on_reconnect( assert mqtt_client_mock.unsubscribe.call_count == 0 mqtt_client_mock.on_disconnect(None, None, 0) - with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): - mqtt_client_mock.on_connect(None, None, None, 0) - await hass.async_block_till_done() + await hass.async_block_till_done() + mqtt_client_mock.on_connect(None, None, None, 0) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown + await hass.async_block_till_done() expected.append(call("test/state", 1)) assert mqtt_client_mock.subscribe.mock_calls == expected + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=3)) # cooldown + await hass.async_block_till_done() + + +@pytest.mark.parametrize( + "mqtt_config_entry_data", + [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], +) +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 1.0) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 1.0) +async def test_subscribed_at_highest_qos( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, + record_calls: MessageCallbackType, +) -> None: + """Test the highest qos as assigned when subscribing to the same topic.""" + mqtt_mock = await mqtt_mock_entry_no_yaml_config() + # Fake that the client is connected + mqtt_mock().connected = True + + await mqtt.async_subscribe(hass, "test/state", record_calls, qos=0) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown + await hass.async_block_till_done() + assert mqtt_client_mock.subscribe.mock_calls == [ + call("test/state", 0), + ] + mqtt_client_mock.reset_mock() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown + await hass.async_block_till_done() + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test/state", record_calls, qos=1) + await mqtt.async_subscribe(hass, "test/state", record_calls, qos=2) + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) # cooldown + await hass.async_block_till_done() + # the subscribtion with the highest QoS should survive + assert mqtt_client_mock.subscribe.mock_calls == [ + call("test/state", 2), + ] + async def test_reload_entry_with_restored_subscriptions( hass: HomeAssistant, @@ -1499,6 +1603,93 @@ async def test_reload_entry_with_restored_subscriptions( assert calls[1].payload == "wild-card-payload3" +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 2) +@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 2) +@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2) +async def test_canceling_debouncer_on_shutdown( + hass: HomeAssistant, + record_calls: MessageCallbackType, + mqtt_client_mock: MqttMockPahoClient, + mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, +) -> None: + """Test canceling the debouncer when HA shuts down.""" + + mqtt_mock = await mqtt_mock_entry_no_yaml_config() + + # Fake that the client is connected + mqtt_mock().connected = True + + await mqtt.async_subscribe(hass, "test/state1", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2)) + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test/state2", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2)) + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test/state3", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2)) + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test/state4", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.2)) + await hass.async_block_till_done() + + await mqtt.async_subscribe(hass, "test/state5", record_calls) + + mqtt_client_mock.subscribe.assert_not_called() + + # Stop HA so the scheduled task will be canceled + hass.bus.fire(EVENT_HOMEASSISTANT_STOP) + # mock disconnect status + mqtt_client_mock.on_disconnect(None, None, 0) + await hass.async_block_till_done() + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + mqtt_client_mock.subscribe.assert_not_called() + + +async def test_canceling_debouncer_normal( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test canceling the debouncer before completion.""" + + async def _async_myjob() -> None: + await asyncio.sleep(1.0) + + debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) + debouncer.async_schedule() + await asyncio.sleep(0.01) + assert debouncer._task is not None + await debouncer.async_cleanup() + assert debouncer._task is None + + +async def test_canceling_debouncer_throws( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test canceling the debouncer when HA shuts down.""" + + async def _async_myjob() -> None: + await asyncio.sleep(1.0) + + debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) + debouncer.async_schedule() + await asyncio.sleep(0.01) + assert debouncer._task is not None + # let debouncer._task fail by mocking it + with patch.object(debouncer, "_task") as task: + task.cancel = MagicMock(return_value=True) + await debouncer.async_cleanup() + assert "Error cleaning up task" in caplog.text + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + + async def test_initial_setup_logs_error( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, @@ -1575,21 +1766,30 @@ async def test_publish_error( assert "Failed to connect to MQTT server: Out of memory." in caplog.text +@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) async def test_subscribe_error( hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, mqtt_client_mock: MqttMockPahoClient, record_calls: MessageCallbackType, + caplog: pytest.LogCaptureFixture, ) -> None: """Test publish error.""" await mqtt_mock_entry_no_yaml_config() mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0) await hass.async_block_till_done() - with pytest.raises(HomeAssistantError): - # simulate client is not connected error before subscribing - mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None) - await mqtt.async_subscribe(hass, "some-topic", record_calls) + await hass.async_block_till_done() + mqtt_client_mock.reset_mock() + # simulate client is not connected error before subscribing + mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None) + await mqtt.async_subscribe(hass, "some-topic", record_calls) + while mqtt_client_mock.subscribe.call_count == 0: await hass.async_block_till_done() + await hass.async_block_till_done() + await hass.async_block_till_done() + assert ( + "Error talking to MQTT: The client is not currently connected." in caplog.text + ) async def test_handle_message_callback(