Index entities by domain for entity services (#106759)
This commit is contained in:
parent
bf0d891f68
commit
09b65f14b9
6 changed files with 63 additions and 60 deletions
|
@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]):
|
|||
|
||||
self.config: ConfigType | None = None
|
||||
|
||||
domain_platform = self._async_init_entity_platform(domain, None)
|
||||
self._platforms: dict[
|
||||
str | tuple[str, timedelta | None, str | None], EntityPlatform
|
||||
] = {domain: self._async_init_entity_platform(domain, None)}
|
||||
self.async_add_entities = self._platforms[domain].async_add_entities
|
||||
self.add_entities = self._platforms[domain].add_entities
|
||||
|
||||
] = {domain: domain_platform}
|
||||
self.async_add_entities = domain_platform.async_add_entities
|
||||
self.add_entities = domain_platform.add_entities
|
||||
self._entities: dict[str, entity.Entity] = domain_platform.domain_entities
|
||||
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
|
||||
|
||||
@property
|
||||
|
@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]):
|
|||
callers that iterate over this asynchronously should make a copy
|
||||
using list() before iterating.
|
||||
"""
|
||||
return chain.from_iterable(
|
||||
platform.entities.values() # type: ignore[misc]
|
||||
for platform in self._platforms.values()
|
||||
)
|
||||
return self._entities.values() # type: ignore[return-value]
|
||||
|
||||
def get_entity(self, entity_id: str) -> _EntityT | None:
|
||||
"""Get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity_obj = platform.entities.get(entity_id)
|
||||
if entity_obj is not None:
|
||||
return entity_obj # type: ignore[return-value]
|
||||
return None
|
||||
return self._entities.get(entity_id) # type: ignore[return-value]
|
||||
|
||||
def register_shutdown(self) -> None:
|
||||
"""Register shutdown on Home Assistant STOP event.
|
||||
|
@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]):
|
|||
"""Handle the service."""
|
||||
|
||||
result = await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
self.hass, self._entities, func, call, required_features
|
||||
)
|
||||
|
||||
if result:
|
||||
|
@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]):
|
|||
) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
self.hass, self._entities, func, call, required_features
|
||||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue