Index entities by domain for entity services (#106759)

This commit is contained in:
J. Nick Koston 2024-01-02 04:28:58 -10:00 committed by GitHub
parent bf0d891f68
commit 09b65f14b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 60 deletions

View file

@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]):
self.config: ConfigType | None = None
domain_platform = self._async_init_entity_platform(domain, None)
self._platforms: dict[
str | tuple[str, timedelta | None, str | None], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)}
self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities
] = {domain: domain_platform}
self.async_add_entities = domain_platform.async_add_entities
self.add_entities = domain_platform.add_entities
self._entities: dict[str, entity.Entity] = domain_platform.domain_entities
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property
@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]):
callers that iterate over this asynchronously should make a copy
using list() before iterating.
"""
return chain.from_iterable(
platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
)
return self._entities.values() # type: ignore[return-value]
def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj # type: ignore[return-value]
return None
return self._entities.get(entity_id) # type: ignore[return-value]
def register_shutdown(self) -> None:
"""Register shutdown on Home Assistant STOP event.
@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]):
"""Handle the service."""
result = await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features
self.hass, self._entities, func, call, required_features
)
if result:
@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]):
) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features
self.hass, self._entities, func, call, required_features
)
self.hass.services.async_register(