Avoid scheduling a task to add each entity when not using update_before_add (#110951)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2024-02-23 10:49:26 -10:00 committed by GitHub
parent 3aecec5082
commit d9addc45f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 194 additions and 59 deletions

View file

@ -346,11 +346,11 @@ class EntityPlatform:
# Block till all entities are done
while self._tasks:
pending = [task for task in self._tasks if not task.done()]
# Await all tasks even if they are done
# to ensure exceptions are propagated
pending = self._tasks.copy()
self._tasks.clear()
if pending:
await asyncio.gather(*pending)
await asyncio.gather(*pending)
hass.config.components.add(full_name)
self._setup_complete = True
@ -505,6 +505,82 @@ class EntityPlatform:
self.hass.loop,
).result()
async def _async_add_and_update_entities(
self,
coros: list[Coroutine[Any, Any, None]],
entities: list[Entity],
timeout: float,
) -> None:
"""Add entities for a single platform and update them.
Since we are updating the entities before adding them, we need to
schedule the coroutines as tasks so we can await them in the event
loop. This is because the update is likely to yield control to the
event loop and will finish faster if we run them concurrently.
"""
results: list[BaseException | None] | None = None
try:
async with self.hass.timeout.async_timeout(timeout, self.domain):
results = await asyncio.gather(*coros, return_exceptions=True)
except TimeoutError:
self.logger.warning(
"Timed out adding entities for domain %s with platform %s after %ds",
self.domain,
self.platform_name,
timeout,
)
if not results:
return
for idx, result in enumerate(results):
if isinstance(result, Exception):
entity = entities[idx]
self.logger.exception(
"Error adding entity %s for domain %s with platform %s",
entity.entity_id,
self.domain,
self.platform_name,
exc_info=result,
)
elif isinstance(result, BaseException):
raise result
async def _async_add_entities(
self,
coros: list[Coroutine[Any, Any, None]],
entities: list[Entity],
timeout: float,
) -> None:
"""Add entities for a single platform without updating.
In this case we are not updating the entities before adding them
which means its unlikely that we will not have to yield control
to the event loop so we can await the coros directly without
scheduling them as tasks.
"""
try:
async with self.hass.timeout.async_timeout(timeout, self.domain):
for idx, coro in enumerate(coros):
try:
await coro
except Exception as ex: # pylint: disable=broad-except
entity = entities[idx]
self.logger.exception(
"Error adding entity %s for domain %s with platform %s",
entity.entity_id,
self.domain,
self.platform_name,
exc_info=ex,
)
except TimeoutError:
self.logger.warning(
"Timed out adding entities for domain %s with platform %s after %ds",
self.domain,
self.platform_name,
timeout,
)
async def async_add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
@ -517,40 +593,31 @@ class EntityPlatform:
return
hass = self.hass
entity_registry = ent_reg.async_get(hass)
tasks = [
self._async_add_entity(entity, update_before_add, entity_registry)
for entity in new_entities
]
coros: list[Coroutine[Any, Any, None]] = []
entities: list[Entity] = []
for entity in new_entities:
coros.append(
self._async_add_entity(entity, update_before_add, entity_registry)
)
entities.append(entity)
# No entities for processing
if not tasks:
if not coros:
return
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(tasks), SLOW_ADD_MIN_TIMEOUT)
try:
async with self.hass.timeout.async_timeout(timeout, self.domain):
await asyncio.gather(*tasks)
except TimeoutError:
self.logger.warning(
"Timed out adding entities for domain %s with platform %s after %ds",
self.domain,
self.platform_name,
timeout,
)
except Exception:
self.logger.exception(
"Error adding entities for domain %s with platform %s",
self.domain,
self.platform_name,
)
raise
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(coros), SLOW_ADD_MIN_TIMEOUT)
if update_before_add:
add_func = self._async_add_and_update_entities
else:
add_func = self._async_add_entities
await add_func(coros, entities, timeout)
if (
(self.config_entry and self.config_entry.pref_disable_polling)
or self._async_unsub_polling is not None
or not any(entity.should_poll for entity in self.entities.values())
or not any(entity.should_poll for entity in entities)
):
return

View file

@ -150,6 +150,7 @@ async def test_state_changed_event_sends_message(
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id="1234")
await platform.async_add_entities([entity])
await hass.async_block_till_done()
mqtt_mock.async_publish.assert_called_with(
"pub/test_domain/test_platform_1234/state", "unknown", 1, True

View file

@ -1992,7 +1992,7 @@ async def test_non_numeric_validation_raise(
state = hass.states.get(entity0.entity_id)
assert state is None
assert ("Error adding entities for domain sensor with platform test") in caplog.text
assert ("for domain sensor with platform test") in caplog.text
@pytest.mark.parametrize(

View file

@ -1747,22 +1747,26 @@ async def test_suggest_report_issue_custom_component(
assert suggestion == "create a bug report at https://some_url"
async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None:
async def test_reuse_entity_object_after_abort(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test reuse entity object."""
platform = MockEntityPlatform(hass, domain="test")
ent = entity.Entity()
ent.entity_id = "invalid"
with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"):
await platform.async_add_entities([ent])
with pytest.raises(
HomeAssistantError,
match="Entity 'invalid' cannot be added a second time to an entity platform",
):
await platform.async_add_entities([ent])
await platform.async_add_entities([ent])
assert "Invalid entity ID: invalid" in caplog.text
await platform.async_add_entities([ent])
assert (
"Entity 'invalid' cannot be added a second time to an entity platform"
in caplog.text
)
async def test_reuse_entity_object_after_entity_registry_remove(
hass: HomeAssistant, entity_registry: er.EntityRegistry
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test reuse entity object."""
entry = entity_registry.async_get_or_create("test", "test", "5678")
@ -1777,15 +1781,15 @@ async def test_reuse_entity_object_after_entity_registry_remove(
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 0
with pytest.raises(
HomeAssistantError,
match="Entity 'test.test_5678' cannot be added a second time",
):
await platform.async_add_entities([ent])
await platform.async_add_entities([ent])
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
assert len(hass.states.async_entity_ids()) == 0
async def test_reuse_entity_object_after_entity_registry_disabled(
hass: HomeAssistant, entity_registry: er.EntityRegistry
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test reuse entity object."""
entry = entity_registry.async_get_or_create("test", "test", "5678")
@ -1802,11 +1806,9 @@ async def test_reuse_entity_object_after_entity_registry_disabled(
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 0
with pytest.raises(
HomeAssistantError,
match="Entity 'test.test_5678' cannot be added a second time",
):
await platform.async_add_entities([ent])
await platform.async_add_entities([ent])
assert len(hass.states.async_entity_ids()) == 0
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
async def test_change_entity_id(

View file

@ -1710,14 +1710,23 @@ async def test_register_entity_service_limited_to_matching_platforms(
}
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_invalid_entity_id(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, update_before_add: bool
) -> None:
"""Test specifying an invalid entity id."""
platform = MockEntityPlatform(hass)
entity = MockEntity(entity_id="invalid_entity_id")
with pytest.raises(HomeAssistantError):
await platform.async_add_entities([entity])
entity2 = MockEntity(entity_id="valid.entity_id")
await platform.async_add_entities(
[entity, entity2], update_before_add=update_before_add
)
assert entity.hass is None
assert entity.platform is None
assert "Invalid entity ID: invalid_entity_id" in caplog.text
# Ensure the valid entity was still added
assert entity2.hass is not None
assert entity2.platform is not None
class MockBlockingEntity(MockEntity):
@ -1728,16 +1737,21 @@ class MockBlockingEntity(MockEntity):
await asyncio.sleep(1000)
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_setup_entry_with_entities_that_block_forever(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
entity_registry: er.EntityRegistry,
update_before_add: bool,
) -> None:
"""Test we cancel adding entities when we reach the timeout."""
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([MockBlockingEntity(name="test1", unique_id="unique")])
async_add_entities(
[MockBlockingEntity(name="test1", unique_id="unique")],
update_before_add=update_before_add,
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
@ -1761,7 +1775,47 @@ async def test_setup_entry_with_entities_that_block_forever(
assert "test" in caplog.text
async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
class MockCancellingEntity(MockEntity):
"""Class to mock an entity get cancelled while adding."""
async def async_added_to_hass(self):
"""Mock cancellation."""
raise asyncio.CancelledError
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_cancellation_is_not_blocked(
hass: HomeAssistant,
update_before_add: bool,
) -> None:
"""Test cancellation is not blocked while adding entities."""
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities(
[MockCancellingEntity(name="test1", unique_id="unique")],
update_before_add=update_before_add,
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(entry_id="super-mock-id")
platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
with pytest.raises(asyncio.CancelledError):
assert await platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
full_name = f"{config_entry.domain}.{platform.domain}"
assert full_name not in hass.config.components
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_two_platforms_add_same_entity(
hass: HomeAssistant, update_before_add: bool
) -> None:
"""Test two platforms in the same domain adding an entity with the same name."""
entity_platform1 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
@ -1774,8 +1828,12 @@ async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
entity2 = SlowEntity(name="entity_1")
await asyncio.gather(
entity_platform1.async_add_entities([entity1]),
entity_platform2.async_add_entities([entity2]),
entity_platform1.async_add_entities(
[entity1], update_before_add=update_before_add
),
entity_platform2.async_add_entities(
[entity2], update_before_add=update_before_add
),
)
entities = []
@ -1816,12 +1874,14 @@ class SlowEntity(MockEntity):
(True, None, "test_domain.device_bla"),
),
)
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_entity_name_influences_entity_id(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
has_entity_name: bool,
entity_name: str | None,
expected_entity_id: str,
update_before_add: bool,
) -> None:
"""Test entity_id is influenced by entity name."""
@ -1839,7 +1899,8 @@ async def test_entity_name_influences_entity_id(
has_entity_name=has_entity_name,
name=entity_name,
),
]
],
update_before_add=update_before_add,
)
return True
@ -1867,12 +1928,14 @@ async def test_entity_name_influences_entity_id(
("cn", True, "test_domain.device_bla_english_name"),
),
)
@pytest.mark.parametrize("update_before_add", (True, False))
async def test_translated_entity_name_influences_entity_id(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
language: str,
has_entity_name: bool,
expected_entity_id: str,
update_before_add: bool,
) -> None:
"""Test entity_id is influenced by translated entity name."""
@ -1909,7 +1972,9 @@ async def test_translated_entity_name_influences_entity_id(
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([TranslatedEntity(has_entity_name)])
async_add_entities(
[TranslatedEntity(has_entity_name)], update_before_add=update_before_add
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)