Track tasks adding entities (#73828)

* Track tasks adding entities

* Update homeassistant/config_entries.py

* fix cast tests

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2022-06-29 09:38:35 +02:00 committed by GitHub
parent 90c68085be
commit 00810235c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 8 deletions

View file

@ -73,6 +73,7 @@ PATH_CONFIG = ".config_entries.json"
SAVE_DELAY = 1
_T = TypeVar("_T", bound="ConfigEntryState")
_R = TypeVar("_R")
class ConfigEntryState(Enum):
@ -193,6 +194,7 @@ class ConfigEntry:
"_async_cancel_retry_setup",
"_on_unload",
"reload_lock",
"_pending_tasks",
)
def __init__(
@ -285,6 +287,8 @@ class ConfigEntry:
# Reload lock to prevent conflicting reloads
self.reload_lock = asyncio.Lock()
self._pending_tasks: list[asyncio.Future[Any]] = []
async def async_setup(
self,
hass: HomeAssistant,
@ -366,7 +370,7 @@ class ConfigEntry:
self.domain,
auth_message,
)
self._async_process_on_unload()
await self._async_process_on_unload()
self.async_start_reauth(hass)
result = False
except ConfigEntryNotReady as ex:
@ -406,7 +410,7 @@ class ConfigEntry:
EVENT_HOMEASSISTANT_STARTED, setup_again
)
self._async_process_on_unload()
await self._async_process_on_unload()
return
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
@ -494,7 +498,7 @@ class ConfigEntry:
self.state = ConfigEntryState.NOT_LOADED
self.reason = None
self._async_process_on_unload()
await self._async_process_on_unload()
# https://github.com/python/mypy/issues/11839
return result # type: ignore[no-any-return]
@ -619,13 +623,18 @@ class ConfigEntry:
self._on_unload = []
self._on_unload.append(func)
@callback
def _async_process_on_unload(self) -> None:
"""Process the on_unload callbacks."""
async def _async_process_on_unload(self) -> None:
"""Process the on_unload callbacks and wait for pending tasks."""
if self._on_unload is not None:
while self._on_unload:
self._on_unload.pop()()
while self._pending_tasks:
pending = [task for task in self._pending_tasks if not task.done()]
self._pending_tasks.clear()
if pending:
await asyncio.gather(*pending)
@callback
def async_start_reauth(self, hass: HomeAssistant) -> None:
"""Start a reauth flow."""
@ -648,6 +657,22 @@ class ConfigEntry:
)
)
@callback
def async_create_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
) -> asyncio.Task[_R]:
"""Create a task from within the eventloop.
This method must be run in the event loop.
target: target to call.
"""
task = hass.async_create_task(target)
self._pending_tasks.append(task)
return task
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
"current_entry", default=None

View file

@ -214,8 +214,9 @@ class EntityPlatform:
def async_create_setup_task() -> Coroutine:
"""Get task to set up platform."""
config_entries.current_entry.set(config_entry)
return platform.async_setup_entry( # type: ignore[no-any-return,union-attr]
self.hass, config_entry, self._async_schedule_add_entities
self.hass, config_entry, self._async_schedule_add_entities_for_entry
)
return await self._async_setup_platform(async_create_setup_task)
@ -334,6 +335,20 @@ class EntityPlatform:
if not self._setup_complete:
self._tasks.append(task)
@callback
def _async_schedule_add_entities_for_entry(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async and track the task."""
assert self.config_entry
task = self.config_entry.async_create_task(
self.hass,
self.async_add_entities(new_entities, update_before_add=update_before_add),
)
if not self._setup_complete:
self._tasks.append(task)
def add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:

View file

@ -127,7 +127,7 @@ async def async_setup_cast(hass, config=None):
config = {}
data = {**{"ignore_cec": [], "known_hosts": [], "uuid": []}, **config}
with patch(
"homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities"
"homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry"
) as add_entities:
entry = MockConfigEntry(data=data, domain="cast")
entry.add_to_hass(hass)