Index entities by domain for entity services (#106759)
This commit is contained in:
parent
bf0d891f68
commit
09b65f14b9
6 changed files with 63 additions and 60 deletions
|
@ -14,6 +14,7 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_platform import async_get_platforms
|
||||
from homeassistant.helpers.service import entity_service_call
|
||||
|
||||
|
@ -120,6 +121,14 @@ SERVICE_SEND_PROGRAM_COMMAND_SCHEMA = vol.All(
|
|||
)
|
||||
|
||||
|
||||
def async_get_entities(hass: HomeAssistant) -> dict[str, Entity]:
|
||||
"""Get entities for a domain."""
|
||||
entities: dict[str, Entity] = {}
|
||||
for platform in async_get_platforms(hass, DOMAIN):
|
||||
entities.update(platform.entities)
|
||||
return entities
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
||||
"""Create and register services for the ISY integration."""
|
||||
|
@ -159,7 +168,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
|||
|
||||
async def _async_send_raw_node_command(call: ServiceCall) -> None:
|
||||
await entity_service_call(
|
||||
hass, async_get_platforms(hass, DOMAIN), "async_send_raw_node_command", call
|
||||
hass, async_get_entities(hass), "async_send_raw_node_command", call
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
|
@ -171,7 +180,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
|||
|
||||
async def _async_send_node_command(call: ServiceCall) -> None:
|
||||
await entity_service_call(
|
||||
hass, async_get_platforms(hass, DOMAIN), "async_send_node_command", call
|
||||
hass, async_get_entities(hass), "async_send_node_command", call
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
|
@ -183,7 +192,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
|||
|
||||
async def _async_get_zwave_parameter(call: ServiceCall) -> None:
|
||||
await entity_service_call(
|
||||
hass, async_get_platforms(hass, DOMAIN), "async_get_zwave_parameter", call
|
||||
hass, async_get_entities(hass), "async_get_zwave_parameter", call
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
|
@ -195,7 +204,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
|||
|
||||
async def _async_set_zwave_parameter(call: ServiceCall) -> None:
|
||||
await entity_service_call(
|
||||
hass, async_get_platforms(hass, DOMAIN), "async_set_zwave_parameter", call
|
||||
hass, async_get_entities(hass), "async_set_zwave_parameter", call
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
|
@ -207,7 +216,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
|
|||
|
||||
async def _async_rename_node(call: ServiceCall) -> None:
|
||||
await entity_service_call(
|
||||
hass, async_get_platforms(hass, DOMAIN), "async_rename_node", call
|
||||
hass, async_get_entities(hass), "async_rename_node", call
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
|
|
|
@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]):
|
|||
|
||||
self.config: ConfigType | None = None
|
||||
|
||||
domain_platform = self._async_init_entity_platform(domain, None)
|
||||
self._platforms: dict[
|
||||
str | tuple[str, timedelta | None, str | None], EntityPlatform
|
||||
] = {domain: self._async_init_entity_platform(domain, None)}
|
||||
self.async_add_entities = self._platforms[domain].async_add_entities
|
||||
self.add_entities = self._platforms[domain].add_entities
|
||||
|
||||
] = {domain: domain_platform}
|
||||
self.async_add_entities = domain_platform.async_add_entities
|
||||
self.add_entities = domain_platform.add_entities
|
||||
self._entities: dict[str, entity.Entity] = domain_platform.domain_entities
|
||||
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
|
||||
|
||||
@property
|
||||
|
@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]):
|
|||
callers that iterate over this asynchronously should make a copy
|
||||
using list() before iterating.
|
||||
"""
|
||||
return chain.from_iterable(
|
||||
platform.entities.values() # type: ignore[misc]
|
||||
for platform in self._platforms.values()
|
||||
)
|
||||
return self._entities.values() # type: ignore[return-value]
|
||||
|
||||
def get_entity(self, entity_id: str) -> _EntityT | None:
|
||||
"""Get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity_obj = platform.entities.get(entity_id)
|
||||
if entity_obj is not None:
|
||||
return entity_obj # type: ignore[return-value]
|
||||
return None
|
||||
return self._entities.get(entity_id) # type: ignore[return-value]
|
||||
|
||||
def register_shutdown(self) -> None:
|
||||
"""Register shutdown on Home Assistant STOP event.
|
||||
|
@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]):
|
|||
"""Handle the service."""
|
||||
|
||||
result = await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
self.hass, self._entities, func, call, required_features
|
||||
)
|
||||
|
||||
if result:
|
||||
|
@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]):
|
|||
) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
self.hass, self._entities, func, call, required_features
|
||||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
|
|
|
@ -55,6 +55,7 @@ SLOW_ADD_MIN_TIMEOUT = 500
|
|||
|
||||
PLATFORM_NOT_READY_RETRIES = 10
|
||||
DATA_ENTITY_PLATFORM = "entity_platform"
|
||||
DATA_DOMAIN_ENTITIES = "domain_entities"
|
||||
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
@ -147,6 +148,10 @@ class EntityPlatform:
|
|||
self.platform_name, []
|
||||
).append(self)
|
||||
|
||||
self.domain_entities: dict[str, Entity] = hass.data.setdefault(
|
||||
DATA_DOMAIN_ENTITIES, {}
|
||||
).setdefault(domain, {})
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent an EntityPlatform."""
|
||||
return (
|
||||
|
@ -734,6 +739,7 @@ class EntityPlatform:
|
|||
|
||||
entity_id = entity.entity_id
|
||||
self.entities[entity_id] = entity
|
||||
self.domain_entities[entity_id] = entity
|
||||
|
||||
if not restored:
|
||||
# Reserve the state in the state machine
|
||||
|
@ -746,6 +752,7 @@ class EntityPlatform:
|
|||
def remove_entity_cb() -> None:
|
||||
"""Remove entity from entities dict."""
|
||||
self.entities.pop(entity_id)
|
||||
self.domain_entities.pop(entity_id)
|
||||
|
||||
entity.async_on_remove(remove_entity_cb)
|
||||
|
||||
|
@ -830,11 +837,7 @@ class EntityPlatform:
|
|||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass,
|
||||
[
|
||||
plf
|
||||
for plf in self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name]
|
||||
if plf.domain == self.domain
|
||||
],
|
||||
self.domain_entities,
|
||||
func,
|
||||
call,
|
||||
required_features,
|
||||
|
|
|
@ -58,7 +58,6 @@ from .typing import ConfigType, TemplateVarsType
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from .entity import Entity
|
||||
from .entity_platform import EntityPlatform
|
||||
|
||||
_EntityT = TypeVar("_EntityT", bound=Entity)
|
||||
|
||||
|
@ -741,7 +740,7 @@ def async_set_service_schema(
|
|||
|
||||
def _get_permissible_entity_candidates(
|
||||
call: ServiceCall,
|
||||
platforms: Iterable[EntityPlatform],
|
||||
entities: dict[str, Entity],
|
||||
entity_perms: None | (Callable[[str, str], bool]),
|
||||
target_all_entities: bool,
|
||||
all_referenced: set[str] | None,
|
||||
|
@ -754,9 +753,8 @@ def _get_permissible_entity_candidates(
|
|||
# is allowed to control.
|
||||
return [
|
||||
entity
|
||||
for platform in platforms
|
||||
for entity in platform.entities.values()
|
||||
if entity_perms(entity.entity_id, POLICY_CONTROL)
|
||||
for entity_id, entity in entities.items()
|
||||
if entity_perms(entity_id, POLICY_CONTROL)
|
||||
]
|
||||
|
||||
assert all_referenced is not None
|
||||
|
@ -771,29 +769,26 @@ def _get_permissible_entity_candidates(
|
|||
)
|
||||
|
||||
elif target_all_entities:
|
||||
return [
|
||||
entity for platform in platforms for entity in platform.entities.values()
|
||||
]
|
||||
return list(entities.values())
|
||||
|
||||
# We have already validated they have permissions to control all_referenced
|
||||
# entities so we do not need to check again.
|
||||
if TYPE_CHECKING:
|
||||
assert all_referenced is not None
|
||||
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]:
|
||||
for platform in platforms:
|
||||
if (entity := platform.entities.get(single_entity)) is not None:
|
||||
if (
|
||||
len(all_referenced) == 1
|
||||
and (single_entity := list(all_referenced)[0])
|
||||
and (entity := entities.get(single_entity)) is not None
|
||||
):
|
||||
return [entity]
|
||||
|
||||
return [
|
||||
platform.entities[entity_id]
|
||||
for platform in platforms
|
||||
for entity_id in all_referenced.intersection(platform.entities)
|
||||
]
|
||||
return [entities[entity_id] for entity_id in all_referenced.intersection(entities)]
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def entity_service_call(
|
||||
hass: HomeAssistant,
|
||||
platforms: Iterable[EntityPlatform],
|
||||
registered_entities: dict[str, Entity],
|
||||
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
||||
call: ServiceCall,
|
||||
required_features: Iterable[int] | None = None,
|
||||
|
@ -832,7 +827,7 @@ async def entity_service_call(
|
|||
# A list with entities to call the service on.
|
||||
entity_candidates = _get_permissible_entity_candidates(
|
||||
call,
|
||||
platforms,
|
||||
registered_entities,
|
||||
entity_perms,
|
||||
target_all_entities,
|
||||
all_referenced,
|
||||
|
|
|
@ -1406,7 +1406,9 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
|
|||
|
||||
with patch.dict(
|
||||
hass.data[tts.DATA_TTS_MANAGER].providers, {}, clear=True
|
||||
), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True):
|
||||
), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True), patch.dict(
|
||||
hass.data[tts.DOMAIN]._entities, {}, clear=True
|
||||
):
|
||||
assert tts.async_resolve_engine(hass, None) is None
|
||||
|
||||
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
|
||||
|
|
|
@ -802,7 +802,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
|
|||
test_service_mock = AsyncMock(return_value=None)
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A],
|
||||
|
@ -821,7 +821,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
|
|||
with pytest.raises(exceptions.HomeAssistantError):
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
ServiceCall(
|
||||
"test_domain", "test_service", {"entity_id": "light.living_room"}
|
||||
|
@ -838,7 +838,7 @@ async def test_call_with_both_required_features(
|
|||
test_service_mock = AsyncMock(return_value=None)
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A | SUPPORT_B],
|
||||
|
@ -857,7 +857,7 @@ async def test_call_with_one_of_required_features(
|
|||
test_service_mock = AsyncMock(return_value=None)
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A, SUPPORT_C],
|
||||
|
@ -878,7 +878,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
|
|||
test_service_mock = Mock(return_value=None)
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
||||
)
|
||||
|
@ -890,7 +890,7 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None:
|
|||
mock_method = mock_entities["light.kitchen"].sync_method = Mock(return_value=None)
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
"sync_method",
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -908,7 +908,7 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None:
|
|||
with pytest.raises(exceptions.UnknownUser) as err:
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[],
|
||||
{},
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -935,7 +935,7 @@ async def test_call_context_target_all(
|
|||
):
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -963,7 +963,7 @@ async def test_call_context_target_specific(
|
|||
):
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -987,7 +987,7 @@ async def test_call_context_target_specific_no_auth(
|
|||
):
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -1007,7 +1007,7 @@ async def test_call_no_context_target_all(
|
|||
"""Check we target all if no user context given."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL}
|
||||
|
@ -1026,7 +1026,7 @@ async def test_call_no_context_target_specific(
|
|||
"""Check we can target specified entities."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall(
|
||||
"test_domain",
|
||||
|
@ -1048,7 +1048,7 @@ async def test_call_with_match_all(
|
|||
"""Check we only target allowed entities if targeting all."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
)
|
||||
|
@ -1065,7 +1065,7 @@ async def test_call_with_omit_entity_id(
|
|||
"""Check service call if we do not pass an entity ID."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
mock_entities,
|
||||
Mock(),
|
||||
ServiceCall("test_domain", "test_service"),
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue