Index entities by domain for entity services (#106759)

This commit is contained in:
J. Nick Koston 2024-01-02 04:28:58 -10:00 committed by GitHub
parent bf0d891f68
commit 09b65f14b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 60 deletions

View file

@ -14,6 +14,7 @@ from homeassistant.const import (
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.helpers.service import entity_service_call from homeassistant.helpers.service import entity_service_call
@ -120,6 +121,14 @@ SERVICE_SEND_PROGRAM_COMMAND_SCHEMA = vol.All(
) )
def async_get_entities(hass: HomeAssistant) -> dict[str, Entity]:
"""Get entities for a domain."""
entities: dict[str, Entity] = {}
for platform in async_get_platforms(hass, DOMAIN):
entities.update(platform.entities)
return entities
@callback @callback
def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901 def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
"""Create and register services for the ISY integration.""" """Create and register services for the ISY integration."""
@ -159,7 +168,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_send_raw_node_command(call: ServiceCall) -> None: async def _async_send_raw_node_command(call: ServiceCall) -> None:
await entity_service_call( await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_send_raw_node_command", call hass, async_get_entities(hass), "async_send_raw_node_command", call
) )
hass.services.async_register( hass.services.async_register(
@ -171,7 +180,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_send_node_command(call: ServiceCall) -> None: async def _async_send_node_command(call: ServiceCall) -> None:
await entity_service_call( await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_send_node_command", call hass, async_get_entities(hass), "async_send_node_command", call
) )
hass.services.async_register( hass.services.async_register(
@ -183,7 +192,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_get_zwave_parameter(call: ServiceCall) -> None: async def _async_get_zwave_parameter(call: ServiceCall) -> None:
await entity_service_call( await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_get_zwave_parameter", call hass, async_get_entities(hass), "async_get_zwave_parameter", call
) )
hass.services.async_register( hass.services.async_register(
@ -195,7 +204,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_set_zwave_parameter(call: ServiceCall) -> None: async def _async_set_zwave_parameter(call: ServiceCall) -> None:
await entity_service_call( await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_set_zwave_parameter", call hass, async_get_entities(hass), "async_set_zwave_parameter", call
) )
hass.services.async_register( hass.services.async_register(
@ -207,7 +216,7 @@ def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
async def _async_rename_node(call: ServiceCall) -> None: async def _async_rename_node(call: ServiceCall) -> None:
await entity_service_call( await entity_service_call(
hass, async_get_platforms(hass, DOMAIN), "async_rename_node", call hass, async_get_entities(hass), "async_rename_node", call
) )
hass.services.async_register( hass.services.async_register(

View file

@ -89,12 +89,13 @@ class EntityComponent(Generic[_EntityT]):
self.config: ConfigType | None = None self.config: ConfigType | None = None
domain_platform = self._async_init_entity_platform(domain, None)
self._platforms: dict[ self._platforms: dict[
str | tuple[str, timedelta | None, str | None], EntityPlatform str | tuple[str, timedelta | None, str | None], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)} ] = {domain: domain_platform}
self.async_add_entities = self._platforms[domain].async_add_entities self.async_add_entities = domain_platform.async_add_entities
self.add_entities = self._platforms[domain].add_entities self.add_entities = domain_platform.add_entities
self._entities: dict[str, entity.Entity] = domain_platform.domain_entities
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property @property
@ -105,18 +106,11 @@ class EntityComponent(Generic[_EntityT]):
callers that iterate over this asynchronously should make a copy callers that iterate over this asynchronously should make a copy
using list() before iterating. using list() before iterating.
""" """
return chain.from_iterable( return self._entities.values() # type: ignore[return-value]
platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
)
def get_entity(self, entity_id: str) -> _EntityT | None: def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity.""" """Get an entity."""
for platform in self._platforms.values(): return self._entities.get(entity_id) # type: ignore[return-value]
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj # type: ignore[return-value]
return None
def register_shutdown(self) -> None: def register_shutdown(self) -> None:
"""Register shutdown on Home Assistant STOP event. """Register shutdown on Home Assistant STOP event.
@ -237,7 +231,7 @@ class EntityComponent(Generic[_EntityT]):
"""Handle the service.""" """Handle the service."""
result = await service.entity_service_call( result = await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features self.hass, self._entities, func, call, required_features
) )
if result: if result:
@ -270,7 +264,7 @@ class EntityComponent(Generic[_EntityT]):
) -> EntityServiceResponse | None: ) -> EntityServiceResponse | None:
"""Handle the service.""" """Handle the service."""
return await service.entity_service_call( return await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features self.hass, self._entities, func, call, required_features
) )
self.hass.services.async_register( self.hass.services.async_register(

View file

@ -55,6 +55,7 @@ SLOW_ADD_MIN_TIMEOUT = 500
PLATFORM_NOT_READY_RETRIES = 10 PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform" DATA_ENTITY_PLATFORM = "entity_platform"
DATA_DOMAIN_ENTITIES = "domain_entities"
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@ -147,6 +148,10 @@ class EntityPlatform:
self.platform_name, [] self.platform_name, []
).append(self) ).append(self)
self.domain_entities: dict[str, Entity] = hass.data.setdefault(
DATA_DOMAIN_ENTITIES, {}
).setdefault(domain, {})
def __repr__(self) -> str: def __repr__(self) -> str:
"""Represent an EntityPlatform.""" """Represent an EntityPlatform."""
return ( return (
@ -734,6 +739,7 @@ class EntityPlatform:
entity_id = entity.entity_id entity_id = entity.entity_id
self.entities[entity_id] = entity self.entities[entity_id] = entity
self.domain_entities[entity_id] = entity
if not restored: if not restored:
# Reserve the state in the state machine # Reserve the state in the state machine
@ -746,6 +752,7 @@ class EntityPlatform:
def remove_entity_cb() -> None: def remove_entity_cb() -> None:
"""Remove entity from entities dict.""" """Remove entity from entities dict."""
self.entities.pop(entity_id) self.entities.pop(entity_id)
self.domain_entities.pop(entity_id)
entity.async_on_remove(remove_entity_cb) entity.async_on_remove(remove_entity_cb)
@ -830,11 +837,7 @@ class EntityPlatform:
"""Handle the service.""" """Handle the service."""
return await service.entity_service_call( return await service.entity_service_call(
self.hass, self.hass,
[ self.domain_entities,
plf
for plf in self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name]
if plf.domain == self.domain
],
func, func,
call, call,
required_features, required_features,

View file

@ -58,7 +58,6 @@ from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING: if TYPE_CHECKING:
from .entity import Entity from .entity import Entity
from .entity_platform import EntityPlatform
_EntityT = TypeVar("_EntityT", bound=Entity) _EntityT = TypeVar("_EntityT", bound=Entity)
@ -741,7 +740,7 @@ def async_set_service_schema(
def _get_permissible_entity_candidates( def _get_permissible_entity_candidates(
call: ServiceCall, call: ServiceCall,
platforms: Iterable[EntityPlatform], entities: dict[str, Entity],
entity_perms: None | (Callable[[str, str], bool]), entity_perms: None | (Callable[[str, str], bool]),
target_all_entities: bool, target_all_entities: bool,
all_referenced: set[str] | None, all_referenced: set[str] | None,
@ -754,9 +753,8 @@ def _get_permissible_entity_candidates(
# is allowed to control. # is allowed to control.
return [ return [
entity entity
for platform in platforms for entity_id, entity in entities.items()
for entity in platform.entities.values() if entity_perms(entity_id, POLICY_CONTROL)
if entity_perms(entity.entity_id, POLICY_CONTROL)
] ]
assert all_referenced is not None assert all_referenced is not None
@ -771,29 +769,26 @@ def _get_permissible_entity_candidates(
) )
elif target_all_entities: elif target_all_entities:
return [ return list(entities.values())
entity for platform in platforms for entity in platform.entities.values()
]
# We have already validated they have permissions to control all_referenced # We have already validated they have permissions to control all_referenced
# entities so we do not need to check again. # entities so we do not need to check again.
if TYPE_CHECKING:
assert all_referenced is not None assert all_referenced is not None
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]: if (
for platform in platforms: len(all_referenced) == 1
if (entity := platform.entities.get(single_entity)) is not None: and (single_entity := list(all_referenced)[0])
and (entity := entities.get(single_entity)) is not None
):
return [entity] return [entity]
return [ return [entities[entity_id] for entity_id in all_referenced.intersection(entities)]
platform.entities[entity_id]
for platform in platforms
for entity_id in all_referenced.intersection(platform.entities)
]
@bind_hass @bind_hass
async def entity_service_call( async def entity_service_call(
hass: HomeAssistant, hass: HomeAssistant,
platforms: Iterable[EntityPlatform], registered_entities: dict[str, Entity],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
call: ServiceCall, call: ServiceCall,
required_features: Iterable[int] | None = None, required_features: Iterable[int] | None = None,
@ -832,7 +827,7 @@ async def entity_service_call(
# A list with entities to call the service on. # A list with entities to call the service on.
entity_candidates = _get_permissible_entity_candidates( entity_candidates = _get_permissible_entity_candidates(
call, call,
platforms, registered_entities,
entity_perms, entity_perms,
target_all_entities, target_all_entities,
all_referenced, all_referenced,

View file

@ -1406,7 +1406,9 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
with patch.dict( with patch.dict(
hass.data[tts.DATA_TTS_MANAGER].providers, {}, clear=True hass.data[tts.DATA_TTS_MANAGER].providers, {}, clear=True
), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True): ), patch.dict(hass.data[tts.DOMAIN]._platforms, {}, clear=True), patch.dict(
hass.data[tts.DOMAIN]._entities, {}, clear=True
):
assert tts.async_resolve_engine(hass, None) is None assert tts.async_resolve_engine(hass, None) is None
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}): with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):

View file

@ -802,7 +802,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
test_service_mock = AsyncMock(return_value=None) test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
test_service_mock, test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A], required_features=[SUPPORT_A],
@ -821,7 +821,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
with pytest.raises(exceptions.HomeAssistantError): with pytest.raises(exceptions.HomeAssistantError):
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
test_service_mock, test_service_mock,
ServiceCall( ServiceCall(
"test_domain", "test_service", {"entity_id": "light.living_room"} "test_domain", "test_service", {"entity_id": "light.living_room"}
@ -838,7 +838,7 @@ async def test_call_with_both_required_features(
test_service_mock = AsyncMock(return_value=None) test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
test_service_mock, test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B], required_features=[SUPPORT_A | SUPPORT_B],
@ -857,7 +857,7 @@ async def test_call_with_one_of_required_features(
test_service_mock = AsyncMock(return_value=None) test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
test_service_mock, test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C], required_features=[SUPPORT_A, SUPPORT_C],
@ -878,7 +878,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
test_service_mock = Mock(return_value=None) test_service_mock = Mock(return_value=None)
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
test_service_mock, test_service_mock,
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
) )
@ -890,7 +890,7 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None:
mock_method = mock_entities["light.kitchen"].sync_method = Mock(return_value=None) mock_method = mock_entities["light.kitchen"].sync_method = Mock(return_value=None)
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
"sync_method", "sync_method",
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -908,7 +908,7 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None:
with pytest.raises(exceptions.UnknownUser) as err: with pytest.raises(exceptions.UnknownUser) as err:
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[], {},
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -935,7 +935,7 @@ async def test_call_context_target_all(
): ):
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -963,7 +963,7 @@ async def test_call_context_target_specific(
): ):
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -987,7 +987,7 @@ async def test_call_context_target_specific_no_auth(
): ):
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -1007,7 +1007,7 @@ async def test_call_no_context_target_all(
"""Check we target all if no user context given.""" """Check we target all if no user context given."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL} "test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL}
@ -1026,7 +1026,7 @@ async def test_call_no_context_target_specific(
"""Check we can target specified entities.""" """Check we can target specified entities."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall( ServiceCall(
"test_domain", "test_domain",
@ -1048,7 +1048,7 @@ async def test_call_with_match_all(
"""Check we only target allowed entities if targeting all.""" """Check we only target allowed entities if targeting all."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
) )
@ -1065,7 +1065,7 @@ async def test_call_with_omit_entity_id(
"""Check service call if we do not pass an entity ID.""" """Check service call if we do not pass an entity ID."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], mock_entities,
Mock(), Mock(),
ServiceCall("test_domain", "test_service"), ServiceCall("test_domain", "test_service"),
) )