diff --git a/homeassistant/core.py b/homeassistant/core.py index 7268b7d8f24..7003b87ce67 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -38,6 +38,7 @@ from typing import ( ) from urllib.parse import urlparse +import async_timeout from typing_extensions import Self import voluptuous as vol import yarl @@ -711,6 +712,14 @@ class HomeAssistant: "Stopping Home Assistant before startup has completed may fail" ) + # Keep holding the reference to the tasks but do not allow them + # to block shutdown. Only tasks created after this point will + # be waited for. + running_tasks = self._tasks + # Avoid clearing here since we want the remove callbacks to fire + # and remove the tasks from the original set which is now running_tasks + self._tasks = set() + # Cancel all background tasks for task in self._background_tasks: self._tasks.add(task) @@ -749,6 +758,35 @@ class HomeAssistant: self.state = CoreState.not_running self.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE) + # Make a copy of running_tasks since a task can finish + # while we are awaiting canceled tasks to get their result + # which will result in the set size changing during iteration + for task in list(running_tasks): + if task.done(): + # Since we made a copy we need to check + # to see if the task finished while we + # were awaiting another task + continue + _LOGGER.warning( + "Task %s was still running after stage 2 shutdown; " + "Integrations should cancel non-critical tasks when receiving " + "the stop event to prevent delaying shutdown", + task, + ) + task.cancel() + try: + async with async_timeout.timeout(0.1): + await task + except asyncio.CancelledError: + pass + except asyncio.TimeoutError: + # Task may be shielded from cancellation. + _LOGGER.exception( + "Task %s could not be canceled during stage 3 shutdown", task + ) + except Exception as ex: # pylint: disable=broad-except + _LOGGER.exception("Task %s error during stage 3 shutdown: %s", task, ex) + # Prevent run_callback_threadsafe from scheduling any additional # callbacks in the event loop as callbacks created on the futures # it returns will never run after the final `self.async_block_till_done` diff --git a/tests/components/airvisual/test_config_flow.py b/tests/components/airvisual/test_config_flow.py index 81c9fb81868..b07a17972f7 100644 --- a/tests/components/airvisual/test_config_flow.py +++ b/tests/components/airvisual/test_config_flow.py @@ -166,3 +166,4 @@ async def test_step_reauth( assert len(hass.config_entries.async_entries()) == 1 assert hass.config_entries.async_entries()[0].data[CONF_API_KEY] == new_api_key + await hass.async_block_till_done() diff --git a/tests/test_core.py b/tests/test_core.py index 4749daa0c0b..eb81efae920 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,6 +9,7 @@ import gc import logging import os from tempfile import TemporaryDirectory +import time from typing import Any from unittest.mock import MagicMock, Mock, PropertyMock, patch @@ -2003,3 +2004,49 @@ async def test_background_task(hass: HomeAssistant) -> None: await asyncio.sleep(0) await hass.async_stop() assert result.result() == ha.CoreState.stopping + + +async def test_shutdown_does_not_block_on_normal_tasks( + hass: HomeAssistant, +) -> None: + """Ensure shutdown does not block on normal tasks.""" + result = asyncio.Future() + unshielded_task = asyncio.sleep(10) + + async def test_task(): + try: + await unshielded_task + except asyncio.CancelledError: + result.set_result(hass.state) + + start = time.monotonic() + task = hass.async_create_task(test_task()) + await asyncio.sleep(0) + await hass.async_stop() + await asyncio.sleep(0) + assert result.done() + assert task.done() + assert time.monotonic() - start < 0.5 + + +async def test_shutdown_does_not_block_on_shielded_tasks( + hass: HomeAssistant, +) -> None: + """Ensure shutdown does not block on shielded tasks.""" + result = asyncio.Future() + shielded_task = asyncio.shield(asyncio.sleep(10)) + + async def test_task(): + try: + await shielded_task + except asyncio.CancelledError: + result.set_result(hass.state) + + start = time.monotonic() + task = hass.async_create_task(test_task()) + await asyncio.sleep(0) + await hass.async_stop() + await asyncio.sleep(0) + assert result.done() + assert task.done() + assert time.monotonic() - start < 0.5