Track tasks adding entities (#73828)
* Track tasks adding entities * Update homeassistant/config_entries.py * fix cast tests Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
90c68085be
commit
00810235c9
3 changed files with 48 additions and 8 deletions
|
@ -73,6 +73,7 @@ PATH_CONFIG = ".config_entries.json"
|
|||
SAVE_DELAY = 1
|
||||
|
||||
_T = TypeVar("_T", bound="ConfigEntryState")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
class ConfigEntryState(Enum):
|
||||
|
@ -193,6 +194,7 @@ class ConfigEntry:
|
|||
"_async_cancel_retry_setup",
|
||||
"_on_unload",
|
||||
"reload_lock",
|
||||
"_pending_tasks",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -285,6 +287,8 @@ class ConfigEntry:
|
|||
# Reload lock to prevent conflicting reloads
|
||||
self.reload_lock = asyncio.Lock()
|
||||
|
||||
self._pending_tasks: list[asyncio.Future[Any]] = []
|
||||
|
||||
async def async_setup(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
|
@ -366,7 +370,7 @@ class ConfigEntry:
|
|||
self.domain,
|
||||
auth_message,
|
||||
)
|
||||
self._async_process_on_unload()
|
||||
await self._async_process_on_unload()
|
||||
self.async_start_reauth(hass)
|
||||
result = False
|
||||
except ConfigEntryNotReady as ex:
|
||||
|
@ -406,7 +410,7 @@ class ConfigEntry:
|
|||
EVENT_HOMEASSISTANT_STARTED, setup_again
|
||||
)
|
||||
|
||||
self._async_process_on_unload()
|
||||
await self._async_process_on_unload()
|
||||
return
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
|
@ -494,7 +498,7 @@ class ConfigEntry:
|
|||
self.state = ConfigEntryState.NOT_LOADED
|
||||
self.reason = None
|
||||
|
||||
self._async_process_on_unload()
|
||||
await self._async_process_on_unload()
|
||||
|
||||
# https://github.com/python/mypy/issues/11839
|
||||
return result # type: ignore[no-any-return]
|
||||
|
@ -619,13 +623,18 @@ class ConfigEntry:
|
|||
self._on_unload = []
|
||||
self._on_unload.append(func)
|
||||
|
||||
@callback
|
||||
def _async_process_on_unload(self) -> None:
|
||||
"""Process the on_unload callbacks."""
|
||||
async def _async_process_on_unload(self) -> None:
|
||||
"""Process the on_unload callbacks and wait for pending tasks."""
|
||||
if self._on_unload is not None:
|
||||
while self._on_unload:
|
||||
self._on_unload.pop()()
|
||||
|
||||
while self._pending_tasks:
|
||||
pending = [task for task in self._pending_tasks if not task.done()]
|
||||
self._pending_tasks.clear()
|
||||
if pending:
|
||||
await asyncio.gather(*pending)
|
||||
|
||||
@callback
|
||||
def async_start_reauth(self, hass: HomeAssistant) -> None:
|
||||
"""Start a reauth flow."""
|
||||
|
@ -648,6 +657,22 @@ class ConfigEntry:
|
|||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_create_task(
|
||||
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
|
||||
) -> asyncio.Task[_R]:
|
||||
"""Create a task from within the eventloop.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
||||
target: target to call.
|
||||
"""
|
||||
task = hass.async_create_task(target)
|
||||
|
||||
self._pending_tasks.append(task)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
|
||||
"current_entry", default=None
|
||||
|
|
|
@ -214,8 +214,9 @@ class EntityPlatform:
|
|||
def async_create_setup_task() -> Coroutine:
|
||||
"""Get task to set up platform."""
|
||||
config_entries.current_entry.set(config_entry)
|
||||
|
||||
return platform.async_setup_entry( # type: ignore[no-any-return,union-attr]
|
||||
self.hass, config_entry, self._async_schedule_add_entities
|
||||
self.hass, config_entry, self._async_schedule_add_entities_for_entry
|
||||
)
|
||||
|
||||
return await self._async_setup_platform(async_create_setup_task)
|
||||
|
@ -334,6 +335,20 @@ class EntityPlatform:
|
|||
if not self._setup_complete:
|
||||
self._tasks.append(task)
|
||||
|
||||
@callback
|
||||
def _async_schedule_add_entities_for_entry(
|
||||
self, new_entities: Iterable[Entity], update_before_add: bool = False
|
||||
) -> None:
|
||||
"""Schedule adding entities for a single platform async and track the task."""
|
||||
assert self.config_entry
|
||||
task = self.config_entry.async_create_task(
|
||||
self.hass,
|
||||
self.async_add_entities(new_entities, update_before_add=update_before_add),
|
||||
)
|
||||
|
||||
if not self._setup_complete:
|
||||
self._tasks.append(task)
|
||||
|
||||
def add_entities(
|
||||
self, new_entities: Iterable[Entity], update_before_add: bool = False
|
||||
) -> None:
|
||||
|
|
|
@ -127,7 +127,7 @@ async def async_setup_cast(hass, config=None):
|
|||
config = {}
|
||||
data = {**{"ignore_cec": [], "known_hosts": [], "uuid": []}, **config}
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities"
|
||||
"homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry"
|
||||
) as add_entities:
|
||||
entry = MockConfigEntry(data=data, domain="cast")
|
||||
entry.add_to_hass(hass)
|
||||
|
|
Loading…
Add table
Reference in a new issue