Index entities by domain for entity services (#106759)

This commit is contained in:
J. Nick Koston 2024-01-02 04:28:58 -10:00 committed by GitHub
parent bf0d891f68
commit 09b65f14b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 60 deletions

View file

@ -58,7 +58,6 @@ from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from .entity import Entity
from .entity_platform import EntityPlatform
_EntityT = TypeVar("_EntityT", bound=Entity)
@ -741,7 +740,7 @@ def async_set_service_schema(
def _get_permissible_entity_candidates(
call: ServiceCall,
platforms: Iterable[EntityPlatform],
entities: dict[str, Entity],
entity_perms: None | (Callable[[str, str], bool]),
target_all_entities: bool,
all_referenced: set[str] | None,
@ -754,9 +753,8 @@ def _get_permissible_entity_candidates(
# is allowed to control.
return [
entity
for platform in platforms
for entity in platform.entities.values()
if entity_perms(entity.entity_id, POLICY_CONTROL)
for entity_id, entity in entities.items()
if entity_perms(entity_id, POLICY_CONTROL)
]
assert all_referenced is not None
@ -771,29 +769,26 @@ def _get_permissible_entity_candidates(
)
elif target_all_entities:
return [
entity for platform in platforms for entity in platform.entities.values()
]
return list(entities.values())
# We have already validated they have permissions to control all_referenced
# entities so we do not need to check again.
assert all_referenced is not None
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]:
for platform in platforms:
if (entity := platform.entities.get(single_entity)) is not None:
return [entity]
if TYPE_CHECKING:
assert all_referenced is not None
if (
len(all_referenced) == 1
and (single_entity := list(all_referenced)[0])
and (entity := entities.get(single_entity)) is not None
):
return [entity]
return [
platform.entities[entity_id]
for platform in platforms
for entity_id in all_referenced.intersection(platform.entities)
]
return [entities[entity_id] for entity_id in all_referenced.intersection(entities)]
@bind_hass
async def entity_service_call(
hass: HomeAssistant,
platforms: Iterable[EntityPlatform],
registered_entities: dict[str, Entity],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
call: ServiceCall,
required_features: Iterable[int] | None = None,
@ -832,7 +827,7 @@ async def entity_service_call(
# A list with entities to call the service on.
entity_candidates = _get_permissible_entity_candidates(
call,
platforms,
registered_entities,
entity_perms,
target_all_entities,
all_referenced,