diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 631933762d2..f2893302e1d 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -346,11 +346,11 @@ class EntityPlatform: # Block till all entities are done while self._tasks: - pending = [task for task in self._tasks if not task.done()] + # Await all tasks even if they are done + # to ensure exceptions are propagated + pending = self._tasks.copy() self._tasks.clear() - - if pending: - await asyncio.gather(*pending) + await asyncio.gather(*pending) hass.config.components.add(full_name) self._setup_complete = True @@ -505,6 +505,82 @@ class EntityPlatform: self.hass.loop, ).result() + async def _async_add_and_update_entities( + self, + coros: list[Coroutine[Any, Any, None]], + entities: list[Entity], + timeout: float, + ) -> None: + """Add entities for a single platform and update them. + + Since we are updating the entities before adding them, we need to + schedule the coroutines as tasks so we can await them in the event + loop. This is because the update is likely to yield control to the + event loop and will finish faster if we run them concurrently. + """ + results: list[BaseException | None] | None = None + try: + async with self.hass.timeout.async_timeout(timeout, self.domain): + results = await asyncio.gather(*coros, return_exceptions=True) + except TimeoutError: + self.logger.warning( + "Timed out adding entities for domain %s with platform %s after %ds", + self.domain, + self.platform_name, + timeout, + ) + + if not results: + return + + for idx, result in enumerate(results): + if isinstance(result, Exception): + entity = entities[idx] + self.logger.exception( + "Error adding entity %s for domain %s with platform %s", + entity.entity_id, + self.domain, + self.platform_name, + exc_info=result, + ) + elif isinstance(result, BaseException): + raise result + + async def _async_add_entities( + self, + coros: list[Coroutine[Any, Any, None]], + entities: list[Entity], + timeout: float, + ) -> None: + """Add entities for a single platform without updating. + + In this case we are not updating the entities before adding them + which means its unlikely that we will not have to yield control + to the event loop so we can await the coros directly without + scheduling them as tasks. + """ + try: + async with self.hass.timeout.async_timeout(timeout, self.domain): + for idx, coro in enumerate(coros): + try: + await coro + except Exception as ex: # pylint: disable=broad-except + entity = entities[idx] + self.logger.exception( + "Error adding entity %s for domain %s with platform %s", + entity.entity_id, + self.domain, + self.platform_name, + exc_info=ex, + ) + except TimeoutError: + self.logger.warning( + "Timed out adding entities for domain %s with platform %s after %ds", + self.domain, + self.platform_name, + timeout, + ) + async def async_add_entities( self, new_entities: Iterable[Entity], update_before_add: bool = False ) -> None: @@ -517,40 +593,31 @@ class EntityPlatform: return hass = self.hass - entity_registry = ent_reg.async_get(hass) - tasks = [ - self._async_add_entity(entity, update_before_add, entity_registry) - for entity in new_entities - ] + coros: list[Coroutine[Any, Any, None]] = [] + entities: list[Entity] = [] + for entity in new_entities: + coros.append( + self._async_add_entity(entity, update_before_add, entity_registry) + ) + entities.append(entity) # No entities for processing - if not tasks: + if not coros: return - timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(tasks), SLOW_ADD_MIN_TIMEOUT) - try: - async with self.hass.timeout.async_timeout(timeout, self.domain): - await asyncio.gather(*tasks) - except TimeoutError: - self.logger.warning( - "Timed out adding entities for domain %s with platform %s after %ds", - self.domain, - self.platform_name, - timeout, - ) - except Exception: - self.logger.exception( - "Error adding entities for domain %s with platform %s", - self.domain, - self.platform_name, - ) - raise + timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(coros), SLOW_ADD_MIN_TIMEOUT) + if update_before_add: + add_func = self._async_add_and_update_entities + else: + add_func = self._async_add_entities + + await add_func(coros, entities, timeout) if ( (self.config_entry and self.config_entry.pref_disable_polling) or self._async_unsub_polling is not None - or not any(entity.should_poll for entity in self.entities.values()) + or not any(entity.should_poll for entity in entities) ): return diff --git a/tests/components/mqtt_statestream/test_init.py b/tests/components/mqtt_statestream/test_init.py index c7bb9d4fcfa..c9e0334d9d9 100644 --- a/tests/components/mqtt_statestream/test_init.py +++ b/tests/components/mqtt_statestream/test_init.py @@ -150,6 +150,7 @@ async def test_state_changed_event_sends_message( platform = MockEntityPlatform(hass) entity = MockEntity(unique_id="1234") await platform.async_add_entities([entity]) + await hass.async_block_till_done() mqtt_mock.async_publish.assert_called_with( "pub/test_domain/test_platform_1234/state", "unknown", 1, True diff --git a/tests/components/sensor/test_init.py b/tests/components/sensor/test_init.py index a120ad8db78..52e1851833e 100644 --- a/tests/components/sensor/test_init.py +++ b/tests/components/sensor/test_init.py @@ -1992,7 +1992,7 @@ async def test_non_numeric_validation_raise( state = hass.states.get(entity0.entity_id) assert state is None - assert ("Error adding entities for domain sensor with platform test") in caplog.text + assert ("for domain sensor with platform test") in caplog.text @pytest.mark.parametrize( diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 4de38cc814d..e9950ec4dfc 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -1747,22 +1747,26 @@ async def test_suggest_report_issue_custom_component( assert suggestion == "create a bug report at https://some_url" -async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None: +async def test_reuse_entity_object_after_abort( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: """Test reuse entity object.""" platform = MockEntityPlatform(hass, domain="test") ent = entity.Entity() ent.entity_id = "invalid" - with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"): - await platform.async_add_entities([ent]) - with pytest.raises( - HomeAssistantError, - match="Entity 'invalid' cannot be added a second time to an entity platform", - ): - await platform.async_add_entities([ent]) + await platform.async_add_entities([ent]) + assert "Invalid entity ID: invalid" in caplog.text + await platform.async_add_entities([ent]) + assert ( + "Entity 'invalid' cannot be added a second time to an entity platform" + in caplog.text + ) async def test_reuse_entity_object_after_entity_registry_remove( - hass: HomeAssistant, entity_registry: er.EntityRegistry + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + caplog: pytest.LogCaptureFixture, ) -> None: """Test reuse entity object.""" entry = entity_registry.async_get_or_create("test", "test", "5678") @@ -1777,15 +1781,15 @@ async def test_reuse_entity_object_after_entity_registry_remove( await hass.async_block_till_done() assert len(hass.states.async_entity_ids()) == 0 - with pytest.raises( - HomeAssistantError, - match="Entity 'test.test_5678' cannot be added a second time", - ): - await platform.async_add_entities([ent]) + await platform.async_add_entities([ent]) + assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text + assert len(hass.states.async_entity_ids()) == 0 async def test_reuse_entity_object_after_entity_registry_disabled( - hass: HomeAssistant, entity_registry: er.EntityRegistry + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + caplog: pytest.LogCaptureFixture, ) -> None: """Test reuse entity object.""" entry = entity_registry.async_get_or_create("test", "test", "5678") @@ -1802,11 +1806,9 @@ async def test_reuse_entity_object_after_entity_registry_disabled( await hass.async_block_till_done() assert len(hass.states.async_entity_ids()) == 0 - with pytest.raises( - HomeAssistantError, - match="Entity 'test.test_5678' cannot be added a second time", - ): - await platform.async_add_entities([ent]) + await platform.async_add_entities([ent]) + assert len(hass.states.async_entity_ids()) == 0 + assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text async def test_change_entity_id( diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index dace6515a38..07ecd7844da 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -1710,14 +1710,23 @@ async def test_register_entity_service_limited_to_matching_platforms( } -async def test_invalid_entity_id(hass: HomeAssistant) -> None: +@pytest.mark.parametrize("update_before_add", (True, False)) +async def test_invalid_entity_id( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, update_before_add: bool +) -> None: """Test specifying an invalid entity id.""" platform = MockEntityPlatform(hass) entity = MockEntity(entity_id="invalid_entity_id") - with pytest.raises(HomeAssistantError): - await platform.async_add_entities([entity]) + entity2 = MockEntity(entity_id="valid.entity_id") + await platform.async_add_entities( + [entity, entity2], update_before_add=update_before_add + ) assert entity.hass is None assert entity.platform is None + assert "Invalid entity ID: invalid_entity_id" in caplog.text + # Ensure the valid entity was still added + assert entity2.hass is not None + assert entity2.platform is not None class MockBlockingEntity(MockEntity): @@ -1728,16 +1737,21 @@ class MockBlockingEntity(MockEntity): await asyncio.sleep(1000) +@pytest.mark.parametrize("update_before_add", (True, False)) async def test_setup_entry_with_entities_that_block_forever( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, entity_registry: er.EntityRegistry, + update_before_add: bool, ) -> None: """Test we cancel adding entities when we reach the timeout.""" async def async_setup_entry(hass, config_entry, async_add_entities): """Mock setup entry method.""" - async_add_entities([MockBlockingEntity(name="test1", unique_id="unique")]) + async_add_entities( + [MockBlockingEntity(name="test1", unique_id="unique")], + update_before_add=update_before_add, + ) return True platform = MockPlatform(async_setup_entry=async_setup_entry) @@ -1761,7 +1775,47 @@ async def test_setup_entry_with_entities_that_block_forever( assert "test" in caplog.text -async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None: +class MockCancellingEntity(MockEntity): + """Class to mock an entity get cancelled while adding.""" + + async def async_added_to_hass(self): + """Mock cancellation.""" + raise asyncio.CancelledError + + +@pytest.mark.parametrize("update_before_add", (True, False)) +async def test_cancellation_is_not_blocked( + hass: HomeAssistant, + update_before_add: bool, +) -> None: + """Test cancellation is not blocked while adding entities.""" + + async def async_setup_entry(hass, config_entry, async_add_entities): + """Mock setup entry method.""" + async_add_entities( + [MockCancellingEntity(name="test1", unique_id="unique")], + update_before_add=update_before_add, + ) + return True + + platform = MockPlatform(async_setup_entry=async_setup_entry) + config_entry = MockConfigEntry(entry_id="super-mock-id") + platform = MockEntityPlatform( + hass, platform_name=config_entry.domain, platform=platform + ) + + with pytest.raises(asyncio.CancelledError): + assert await platform.async_setup_entry(config_entry) + await hass.async_block_till_done() + + full_name = f"{config_entry.domain}.{platform.domain}" + assert full_name not in hass.config.components + + +@pytest.mark.parametrize("update_before_add", (True, False)) +async def test_two_platforms_add_same_entity( + hass: HomeAssistant, update_before_add: bool +) -> None: """Test two platforms in the same domain adding an entity with the same name.""" entity_platform1 = MockEntityPlatform( hass, domain="mock_integration", platform_name="mock_platform", platform=None @@ -1774,8 +1828,12 @@ async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None: entity2 = SlowEntity(name="entity_1") await asyncio.gather( - entity_platform1.async_add_entities([entity1]), - entity_platform2.async_add_entities([entity2]), + entity_platform1.async_add_entities( + [entity1], update_before_add=update_before_add + ), + entity_platform2.async_add_entities( + [entity2], update_before_add=update_before_add + ), ) entities = [] @@ -1816,12 +1874,14 @@ class SlowEntity(MockEntity): (True, None, "test_domain.device_bla"), ), ) +@pytest.mark.parametrize("update_before_add", (True, False)) async def test_entity_name_influences_entity_id( hass: HomeAssistant, entity_registry: er.EntityRegistry, has_entity_name: bool, entity_name: str | None, expected_entity_id: str, + update_before_add: bool, ) -> None: """Test entity_id is influenced by entity name.""" @@ -1839,7 +1899,8 @@ async def test_entity_name_influences_entity_id( has_entity_name=has_entity_name, name=entity_name, ), - ] + ], + update_before_add=update_before_add, ) return True @@ -1867,12 +1928,14 @@ async def test_entity_name_influences_entity_id( ("cn", True, "test_domain.device_bla_english_name"), ), ) +@pytest.mark.parametrize("update_before_add", (True, False)) async def test_translated_entity_name_influences_entity_id( hass: HomeAssistant, entity_registry: er.EntityRegistry, language: str, has_entity_name: bool, expected_entity_id: str, + update_before_add: bool, ) -> None: """Test entity_id is influenced by translated entity name.""" @@ -1909,7 +1972,9 @@ async def test_translated_entity_name_influences_entity_id( async def async_setup_entry(hass, config_entry, async_add_entities): """Mock setup entry method.""" - async_add_entities([TranslatedEntity(has_entity_name)]) + async_add_entities( + [TranslatedEntity(has_entity_name)], update_before_add=update_before_add + ) return True platform = MockPlatform(async_setup_entry=async_setup_entry)