diff --git a/homeassistant/components/yeelight/scanner.py b/homeassistant/components/yeelight/scanner.py index c98ca625029..6ca12e9bd01 100644 --- a/homeassistant/components/yeelight/scanner.py +++ b/homeassistant/components/yeelight/scanner.py @@ -3,9 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, ValuesView +from collections.abc import ValuesView import contextlib from datetime import datetime +from functools import partial from ipaddress import IPv4Address import logging from typing import Self @@ -19,6 +20,7 @@ from homeassistant.components import network, ssdp from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import discovery_flow from homeassistant.helpers.event import async_call_later, async_track_time_interval +from homeassistant.util.async_ import create_eager_task from .const import ( DISCOVERY_ATTEMPTS, @@ -33,6 +35,12 @@ from .const import ( _LOGGER = logging.getLogger(__name__) +@callback +def _set_future_if_not_done(future: asyncio.Future[None]) -> None: + if not future.done(): + future.set_result(None) + + class YeelightScanner: """Scan for Yeelight devices.""" @@ -54,26 +62,18 @@ class YeelightScanner: self._host_capabilities: dict[str, CaseInsensitiveDict] = {} self._track_interval: CALLBACK_TYPE | None = None self._listeners: list[SsdpSearchListener] = [] - self._connected_events: list[asyncio.Event] = [] + self._setup_future: asyncio.Future[None] | None = None async def async_setup(self) -> None: """Set up the scanner.""" - if self._connected_events: - await self._async_wait_connected() - return - - for idx, source_ip in enumerate(await self._async_build_source_set()): - self._connected_events.append(asyncio.Event()) - - def _wrap_async_connected_idx(idx) -> Callable[[], None]: - """Create a function to capture the idx cell variable.""" - - @callback - def _async_connected() -> None: - self._connected_events[idx].set() - - return _async_connected + if self._setup_future is not None: + return await self._setup_future + self._setup_future = self._hass.loop.create_future() + connected_futures: list[asyncio.Future[None]] = [] + for source_ip in await self._async_build_source_set(): + future = self._hass.loop.create_future() + connected_futures.append(future) source = (str(source_ip), 0) self._listeners.append( SsdpSearchListener( @@ -81,12 +81,15 @@ class YeelightScanner: search_target=SSDP_ST, target=SSDP_TARGET, source=source, - connect_callback=_wrap_async_connected_idx(idx), + connect_callback=partial(_set_future_if_not_done, future), ) ) results = await asyncio.gather( - *(listener.async_start() for listener in self._listeners), + *( + create_eager_task(listener.async_start()) + for listener in self._listeners + ), return_exceptions=True, ) failed_listeners = [] @@ -99,20 +102,17 @@ class YeelightScanner: result, ) failed_listeners.append(self._listeners[idx]) - self._connected_events[idx].set() + _set_future_if_not_done(connected_futures[idx]) for listener in failed_listeners: self._listeners.remove(listener) - await self._async_wait_connected() + await asyncio.wait(connected_futures) self._track_interval = async_track_time_interval( self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True ) self.async_scan() - - async def _async_wait_connected(self): - """Wait for the listeners to be up and connected.""" - await asyncio.gather(*(event.wait() for event in self._connected_events)) + _set_future_if_not_done(self._setup_future) async def _async_build_source_set(self) -> set[IPv4Address]: """Build the list of ssdp sources.""" diff --git a/tests/components/yeelight/test_light.py b/tests/components/yeelight/test_light.py index 052b6d3223a..ff80c2b55b2 100644 --- a/tests/components/yeelight/test_light.py +++ b/tests/components/yeelight/test_light.py @@ -1413,6 +1413,7 @@ async def test_effects(hass: HomeAssistant) -> None: } }, ) + await hass.async_block_till_done() config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA) config_entry.add_to_hass(hass)