Avoid scheduling a task to add each entity when not using update_before_add (#110951)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
3aecec5082
commit
d9addc45f9
5 changed files with 194 additions and 59 deletions
|
@ -346,11 +346,11 @@ class EntityPlatform:
|
|||
|
||||
# Block till all entities are done
|
||||
while self._tasks:
|
||||
pending = [task for task in self._tasks if not task.done()]
|
||||
# Await all tasks even if they are done
|
||||
# to ensure exceptions are propagated
|
||||
pending = self._tasks.copy()
|
||||
self._tasks.clear()
|
||||
|
||||
if pending:
|
||||
await asyncio.gather(*pending)
|
||||
await asyncio.gather(*pending)
|
||||
|
||||
hass.config.components.add(full_name)
|
||||
self._setup_complete = True
|
||||
|
@ -505,6 +505,82 @@ class EntityPlatform:
|
|||
self.hass.loop,
|
||||
).result()
|
||||
|
||||
async def _async_add_and_update_entities(
|
||||
self,
|
||||
coros: list[Coroutine[Any, Any, None]],
|
||||
entities: list[Entity],
|
||||
timeout: float,
|
||||
) -> None:
|
||||
"""Add entities for a single platform and update them.
|
||||
|
||||
Since we are updating the entities before adding them, we need to
|
||||
schedule the coroutines as tasks so we can await them in the event
|
||||
loop. This is because the update is likely to yield control to the
|
||||
event loop and will finish faster if we run them concurrently.
|
||||
"""
|
||||
results: list[BaseException | None] | None = None
|
||||
try:
|
||||
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
||||
results = await asyncio.gather(*coros, return_exceptions=True)
|
||||
except TimeoutError:
|
||||
self.logger.warning(
|
||||
"Timed out adding entities for domain %s with platform %s after %ds",
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
timeout,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
entity = entities[idx]
|
||||
self.logger.exception(
|
||||
"Error adding entity %s for domain %s with platform %s",
|
||||
entity.entity_id,
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
exc_info=result,
|
||||
)
|
||||
elif isinstance(result, BaseException):
|
||||
raise result
|
||||
|
||||
async def _async_add_entities(
|
||||
self,
|
||||
coros: list[Coroutine[Any, Any, None]],
|
||||
entities: list[Entity],
|
||||
timeout: float,
|
||||
) -> None:
|
||||
"""Add entities for a single platform without updating.
|
||||
|
||||
In this case we are not updating the entities before adding them
|
||||
which means its unlikely that we will not have to yield control
|
||||
to the event loop so we can await the coros directly without
|
||||
scheduling them as tasks.
|
||||
"""
|
||||
try:
|
||||
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
||||
for idx, coro in enumerate(coros):
|
||||
try:
|
||||
await coro
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
entity = entities[idx]
|
||||
self.logger.exception(
|
||||
"Error adding entity %s for domain %s with platform %s",
|
||||
entity.entity_id,
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
exc_info=ex,
|
||||
)
|
||||
except TimeoutError:
|
||||
self.logger.warning(
|
||||
"Timed out adding entities for domain %s with platform %s after %ds",
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
timeout,
|
||||
)
|
||||
|
||||
async def async_add_entities(
|
||||
self, new_entities: Iterable[Entity], update_before_add: bool = False
|
||||
) -> None:
|
||||
|
@ -517,40 +593,31 @@ class EntityPlatform:
|
|||
return
|
||||
|
||||
hass = self.hass
|
||||
|
||||
entity_registry = ent_reg.async_get(hass)
|
||||
tasks = [
|
||||
self._async_add_entity(entity, update_before_add, entity_registry)
|
||||
for entity in new_entities
|
||||
]
|
||||
coros: list[Coroutine[Any, Any, None]] = []
|
||||
entities: list[Entity] = []
|
||||
for entity in new_entities:
|
||||
coros.append(
|
||||
self._async_add_entity(entity, update_before_add, entity_registry)
|
||||
)
|
||||
entities.append(entity)
|
||||
|
||||
# No entities for processing
|
||||
if not tasks:
|
||||
if not coros:
|
||||
return
|
||||
|
||||
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(tasks), SLOW_ADD_MIN_TIMEOUT)
|
||||
try:
|
||||
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
||||
await asyncio.gather(*tasks)
|
||||
except TimeoutError:
|
||||
self.logger.warning(
|
||||
"Timed out adding entities for domain %s with platform %s after %ds",
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
timeout,
|
||||
)
|
||||
except Exception:
|
||||
self.logger.exception(
|
||||
"Error adding entities for domain %s with platform %s",
|
||||
self.domain,
|
||||
self.platform_name,
|
||||
)
|
||||
raise
|
||||
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(coros), SLOW_ADD_MIN_TIMEOUT)
|
||||
if update_before_add:
|
||||
add_func = self._async_add_and_update_entities
|
||||
else:
|
||||
add_func = self._async_add_entities
|
||||
|
||||
await add_func(coros, entities, timeout)
|
||||
|
||||
if (
|
||||
(self.config_entry and self.config_entry.pref_disable_polling)
|
||||
or self._async_unsub_polling is not None
|
||||
or not any(entity.should_poll for entity in self.entities.values())
|
||||
or not any(entity.should_poll for entity in entities)
|
||||
):
|
||||
return
|
||||
|
||||
|
|
|
@ -150,6 +150,7 @@ async def test_state_changed_event_sends_message(
|
|||
platform = MockEntityPlatform(hass)
|
||||
entity = MockEntity(unique_id="1234")
|
||||
await platform.async_add_entities([entity])
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mqtt_mock.async_publish.assert_called_with(
|
||||
"pub/test_domain/test_platform_1234/state", "unknown", 1, True
|
||||
|
|
|
@ -1992,7 +1992,7 @@ async def test_non_numeric_validation_raise(
|
|||
state = hass.states.get(entity0.entity_id)
|
||||
assert state is None
|
||||
|
||||
assert ("Error adding entities for domain sensor with platform test") in caplog.text
|
||||
assert ("for domain sensor with platform test") in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -1747,22 +1747,26 @@ async def test_suggest_report_issue_custom_component(
|
|||
assert suggestion == "create a bug report at https://some_url"
|
||||
|
||||
|
||||
async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None:
|
||||
async def test_reuse_entity_object_after_abort(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test reuse entity object."""
|
||||
platform = MockEntityPlatform(hass, domain="test")
|
||||
ent = entity.Entity()
|
||||
ent.entity_id = "invalid"
|
||||
with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"):
|
||||
await platform.async_add_entities([ent])
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="Entity 'invalid' cannot be added a second time to an entity platform",
|
||||
):
|
||||
await platform.async_add_entities([ent])
|
||||
await platform.async_add_entities([ent])
|
||||
assert "Invalid entity ID: invalid" in caplog.text
|
||||
await platform.async_add_entities([ent])
|
||||
assert (
|
||||
"Entity 'invalid' cannot be added a second time to an entity platform"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_reuse_entity_object_after_entity_registry_remove(
|
||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test reuse entity object."""
|
||||
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
||||
|
@ -1777,15 +1781,15 @@ async def test_reuse_entity_object_after_entity_registry_remove(
|
|||
await hass.async_block_till_done()
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="Entity 'test.test_5678' cannot be added a second time",
|
||||
):
|
||||
await platform.async_add_entities([ent])
|
||||
await platform.async_add_entities([ent])
|
||||
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
||||
|
||||
async def test_reuse_entity_object_after_entity_registry_disabled(
|
||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test reuse entity object."""
|
||||
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
||||
|
@ -1802,11 +1806,9 @@ async def test_reuse_entity_object_after_entity_registry_disabled(
|
|||
await hass.async_block_till_done()
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="Entity 'test.test_5678' cannot be added a second time",
|
||||
):
|
||||
await platform.async_add_entities([ent])
|
||||
await platform.async_add_entities([ent])
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
|
||||
|
||||
|
||||
async def test_change_entity_id(
|
||||
|
|
|
@ -1710,14 +1710,23 @@ async def test_register_entity_service_limited_to_matching_platforms(
|
|||
}
|
||||
|
||||
|
||||
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_invalid_entity_id(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, update_before_add: bool
|
||||
) -> None:
|
||||
"""Test specifying an invalid entity id."""
|
||||
platform = MockEntityPlatform(hass)
|
||||
entity = MockEntity(entity_id="invalid_entity_id")
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await platform.async_add_entities([entity])
|
||||
entity2 = MockEntity(entity_id="valid.entity_id")
|
||||
await platform.async_add_entities(
|
||||
[entity, entity2], update_before_add=update_before_add
|
||||
)
|
||||
assert entity.hass is None
|
||||
assert entity.platform is None
|
||||
assert "Invalid entity ID: invalid_entity_id" in caplog.text
|
||||
# Ensure the valid entity was still added
|
||||
assert entity2.hass is not None
|
||||
assert entity2.platform is not None
|
||||
|
||||
|
||||
class MockBlockingEntity(MockEntity):
|
||||
|
@ -1728,16 +1737,21 @@ class MockBlockingEntity(MockEntity):
|
|||
await asyncio.sleep(1000)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_setup_entry_with_entities_that_block_forever(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
entity_registry: er.EntityRegistry,
|
||||
update_before_add: bool,
|
||||
) -> None:
|
||||
"""Test we cancel adding entities when we reach the timeout."""
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Mock setup entry method."""
|
||||
async_add_entities([MockBlockingEntity(name="test1", unique_id="unique")])
|
||||
async_add_entities(
|
||||
[MockBlockingEntity(name="test1", unique_id="unique")],
|
||||
update_before_add=update_before_add,
|
||||
)
|
||||
return True
|
||||
|
||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
|
@ -1761,7 +1775,47 @@ async def test_setup_entry_with_entities_that_block_forever(
|
|||
assert "test" in caplog.text
|
||||
|
||||
|
||||
async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
|
||||
class MockCancellingEntity(MockEntity):
|
||||
"""Class to mock an entity get cancelled while adding."""
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Mock cancellation."""
|
||||
raise asyncio.CancelledError
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_cancellation_is_not_blocked(
|
||||
hass: HomeAssistant,
|
||||
update_before_add: bool,
|
||||
) -> None:
|
||||
"""Test cancellation is not blocked while adding entities."""
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Mock setup entry method."""
|
||||
async_add_entities(
|
||||
[MockCancellingEntity(name="test1", unique_id="unique")],
|
||||
update_before_add=update_before_add,
|
||||
)
|
||||
return True
|
||||
|
||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
config_entry = MockConfigEntry(entry_id="super-mock-id")
|
||||
platform = MockEntityPlatform(
|
||||
hass, platform_name=config_entry.domain, platform=platform
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
assert await platform.async_setup_entry(config_entry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
full_name = f"{config_entry.domain}.{platform.domain}"
|
||||
assert full_name not in hass.config.components
|
||||
|
||||
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_two_platforms_add_same_entity(
|
||||
hass: HomeAssistant, update_before_add: bool
|
||||
) -> None:
|
||||
"""Test two platforms in the same domain adding an entity with the same name."""
|
||||
entity_platform1 = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
|
@ -1774,8 +1828,12 @@ async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
|
|||
entity2 = SlowEntity(name="entity_1")
|
||||
|
||||
await asyncio.gather(
|
||||
entity_platform1.async_add_entities([entity1]),
|
||||
entity_platform2.async_add_entities([entity2]),
|
||||
entity_platform1.async_add_entities(
|
||||
[entity1], update_before_add=update_before_add
|
||||
),
|
||||
entity_platform2.async_add_entities(
|
||||
[entity2], update_before_add=update_before_add
|
||||
),
|
||||
)
|
||||
|
||||
entities = []
|
||||
|
@ -1816,12 +1874,14 @@ class SlowEntity(MockEntity):
|
|||
(True, None, "test_domain.device_bla"),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_entity_name_influences_entity_id(
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
has_entity_name: bool,
|
||||
entity_name: str | None,
|
||||
expected_entity_id: str,
|
||||
update_before_add: bool,
|
||||
) -> None:
|
||||
"""Test entity_id is influenced by entity name."""
|
||||
|
||||
|
@ -1839,7 +1899,8 @@ async def test_entity_name_influences_entity_id(
|
|||
has_entity_name=has_entity_name,
|
||||
name=entity_name,
|
||||
),
|
||||
]
|
||||
],
|
||||
update_before_add=update_before_add,
|
||||
)
|
||||
return True
|
||||
|
||||
|
@ -1867,12 +1928,14 @@ async def test_entity_name_influences_entity_id(
|
|||
("cn", True, "test_domain.device_bla_english_name"),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||
async def test_translated_entity_name_influences_entity_id(
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
language: str,
|
||||
has_entity_name: bool,
|
||||
expected_entity_id: str,
|
||||
update_before_add: bool,
|
||||
) -> None:
|
||||
"""Test entity_id is influenced by translated entity name."""
|
||||
|
||||
|
@ -1909,7 +1972,9 @@ async def test_translated_entity_name_influences_entity_id(
|
|||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Mock setup entry method."""
|
||||
async_add_entities([TranslatedEntity(has_entity_name)])
|
||||
async_add_entities(
|
||||
[TranslatedEntity(has_entity_name)], update_before_add=update_before_add
|
||||
)
|
||||
return True
|
||||
|
||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
|
|
Loading…
Add table
Reference in a new issue