diff --git a/homeassistant/util/async_.py b/homeassistant/util/async_.py index d010d8cb341..8e4c52b3414 100644 --- a/homeassistant/util/async_.py +++ b/homeassistant/util/async_.py @@ -5,22 +5,28 @@ from __future__ import annotations from asyncio import ( AbstractEventLoop, Future, + Queue, Semaphore, Task, TimerHandle, gather, get_running_loop, + timeout as async_timeout, ) -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine import concurrent.futures import logging import threading from typing import Any +from typing_extensions import TypeVar + _LOGGER = logging.getLogger(__name__) _SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe" +_DataT = TypeVar("_DataT", default=Any) + def create_eager_task[_T]( coro: Coroutine[Any, Any, _T], @@ -138,3 +144,19 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]: """Return a list of scheduled TimerHandles.""" handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001 return handles + + +async def queue_to_iterable( + queue: Queue[_DataT | None], timeout: float | None = None +) -> AsyncIterable[_DataT]: + """Stream items from a queue until None with an optional timeout per item.""" + if timeout is None: + while (item := await queue.get()) is not None: + yield item + else: + while True: + async with async_timeout(timeout): + item = await queue.get() + if item is None: + break + yield item diff --git a/tests/util/test_async.py b/tests/util/test_async.py index cda10b69c3f..878a31fff95 100644 --- a/tests/util/test_async.py +++ b/tests/util/test_async.py @@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None: timer_handle.cancel() timer_handle2.cancel() timer_handle3.cancel() + + +async def test_queue_to_iterable() -> None: + """Test queue_to_iterable.""" + queue: asyncio.Queue[int | None] = asyncio.Queue() + expected_items = list(range(10)) + + for i in expected_items: + await queue.put(i) + + # Will terminate the stream + await queue.put(None) + + actual_items = [item async for item in hasync.queue_to_iterable(queue)] + + assert expected_items == actual_items + + # Check timeout + assert queue.empty() + + # Time out on first item + async with asyncio.timeout(1): + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + # Should time out very quickly + async for _item in hasync.queue_to_iterable(queue, timeout=0.01): + await asyncio.sleep(1) + + # Check timeout on second item + assert queue.empty() + await queue.put(12345) + + # Time out on second item + async with asyncio.timeout(1): + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + # Should time out very quickly + async for item in hasync.queue_to_iterable(queue, timeout=0.01): + if item != 12345: + await asyncio.sleep(1) + + assert queue.empty()