Reduce overhead to call entity services (#106908)

This commit is contained in:
J. Nick Koston 2024-01-07 22:30:52 -10:00 committed by GitHub
parent 9ad3c8dbc9
commit d260ed938a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 36 deletions

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from datetime import timedelta from datetime import timedelta
from functools import partial
from itertools import chain from itertools import chain
import logging import logging
from types import ModuleType from types import ModuleType
@ -20,8 +21,8 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.core import ( from homeassistant.core import (
EntityServiceResponse,
Event, Event,
HassJob,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse, ServiceResponse,
@ -225,13 +226,16 @@ class EntityComponent(Generic[_EntityT]):
if isinstance(schema, dict): if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema) schema = cv.make_entity_service_schema(schema)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
async def handle_service( async def handle_service(
call: ServiceCall, call: ServiceCall,
) -> ServiceResponse: ) -> ServiceResponse:
"""Handle the service.""" """Handle the service."""
result = await service.entity_service_call( result = await service.entity_service_call(
self.hass, self._entities, func, call, required_features self.hass, self._entities, service_func, call, required_features
) )
if result: if result:
@ -259,16 +263,21 @@ class EntityComponent(Generic[_EntityT]):
if isinstance(schema, dict): if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema) schema = cv.make_entity_service_schema(schema)
async def handle_service( service_func: str | HassJob[..., Any]
call: ServiceCall, service_func = func if isinstance(func, str) else HassJob(func)
) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass, self._entities, func, call, required_features
)
self.hass.services.async_register( self.hass.services.async_register(
self.domain, name, handle_service, schema, supports_response self.domain,
name,
partial(
service.entity_service_call,
self.hass,
self._entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
) )
async def async_setup_platform( async def async_setup_platform(

View file

@ -5,6 +5,7 @@ import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
from logging import Logger, getLogger from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
@ -20,7 +21,7 @@ from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
DOMAIN as HOMEASSISTANT_DOMAIN, DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState, CoreState,
EntityServiceResponse, HassJob,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
SupportsResponse, SupportsResponse,
@ -833,18 +834,21 @@ class EntityPlatform:
if isinstance(schema, dict): if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema) schema = cv.make_entity_service_schema(schema)
async def handle_service(call: ServiceCall) -> EntityServiceResponse | None: service_func: str | HassJob[..., Any]
"""Handle the service.""" service_func = func if isinstance(func, str) else HassJob(func)
return await service.entity_service_call(
self.hass,
self.domain_entities,
func,
call,
required_features,
)
self.hass.services.async_register( self.hass.services.async_register(
self.platform_name, name, handle_service, schema, supports_response self.platform_name,
name,
partial(
service.entity_service_call,
self.hass,
self.domain_entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
) )
async def _update_entity_states(self, now: datetime) -> None: async def _update_entity_states(self, now: datetime) -> None:

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable from collections.abc import Awaitable, Callable, Iterable
import dataclasses import dataclasses
from enum import Enum from enum import Enum
from functools import cache, partial, wraps from functools import cache, partial, wraps
@ -29,6 +29,7 @@ from homeassistant.const import (
from homeassistant.core import ( from homeassistant.core import (
Context, Context,
EntityServiceResponse, EntityServiceResponse,
HassJob,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse, ServiceResponse,
@ -191,11 +192,14 @@ class ServiceParams(TypedDict):
class ServiceTargetSelector: class ServiceTargetSelector:
"""Class to hold a target selector for a service.""" """Class to hold a target selector for a service."""
__slots__ = ("entity_ids", "device_ids", "area_ids")
def __init__(self, service_call: ServiceCall) -> None: def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data.""" """Extract ids from service call data."""
entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID) service_call_data = service_call.data
device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID) entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID) device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
self.entity_ids = ( self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
@ -790,7 +794,7 @@ def _get_permissible_entity_candidates(
async def entity_service_call( async def entity_service_call(
hass: HomeAssistant, hass: HomeAssistant,
registered_entities: dict[str, Entity], registered_entities: dict[str, Entity],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], func: str | HassJob,
call: ServiceCall, call: ServiceCall,
required_features: Iterable[int] | None = None, required_features: Iterable[int] | None = None,
) -> EntityServiceResponse | None: ) -> EntityServiceResponse | None:
@ -926,7 +930,7 @@ async def entity_service_call(
async def _handle_entity_call( async def _handle_entity_call(
hass: HomeAssistant, hass: HomeAssistant,
entity: Entity, entity: Entity,
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], func: str | HassJob,
data: dict | ServiceCall, data: dict | ServiceCall,
context: Context, context: Context,
) -> ServiceResponse: ) -> ServiceResponse:
@ -935,11 +939,11 @@ async def _handle_entity_call(
task: asyncio.Future[ServiceResponse] | None task: asyncio.Future[ServiceResponse] | None
if isinstance(func, str): if isinstance(func, str):
task = hass.async_run_job( task = hass.async_run_hass_job(
partial(getattr(entity, func), **data) # type: ignore[arg-type] HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type]
) )
else: else:
task = hass.async_run_job(func, entity, data) task = hass.async_run_hass_job(func, entity, data)
# Guard because callback functions do not return a task when passed to # Guard because callback functions do not return a task when passed to
# async_run_job. # async_run_job.

View file

@ -19,7 +19,13 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
EntityCategory, EntityCategory,
) )
from homeassistant.core import Context, HomeAssistant, ServiceCall, SupportsResponse from homeassistant.core import (
Context,
HassJob,
HomeAssistant,
ServiceCall,
SupportsResponse,
)
from homeassistant.helpers import ( from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
@ -803,7 +809,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
await service.entity_service_call( await service.entity_service_call(
hass, hass,
mock_entities, mock_entities,
test_service_mock, HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A], required_features=[SUPPORT_A],
) )
@ -822,7 +828,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
await service.entity_service_call( await service.entity_service_call(
hass, hass,
mock_entities, mock_entities,
test_service_mock, HassJob(test_service_mock),
ServiceCall( ServiceCall(
"test_domain", "test_service", {"entity_id": "light.living_room"} "test_domain", "test_service", {"entity_id": "light.living_room"}
), ),
@ -839,7 +845,7 @@ async def test_call_with_both_required_features(
await service.entity_service_call( await service.entity_service_call(
hass, hass,
mock_entities, mock_entities,
test_service_mock, HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B], required_features=[SUPPORT_A | SUPPORT_B],
) )
@ -858,7 +864,7 @@ async def test_call_with_one_of_required_features(
await service.entity_service_call( await service.entity_service_call(
hass, hass,
mock_entities, mock_entities,
test_service_mock, HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C], required_features=[SUPPORT_A, SUPPORT_C],
) )
@ -879,7 +885,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
await service.entity_service_call( await service.entity_service_call(
hass, hass,
mock_entities, mock_entities,
test_service_mock, HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
) )
assert test_service_mock.call_count == 1 assert test_service_mock.call_count == 1