diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index b3eb8722997..e49acc71d07 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Iterable from datetime import timedelta +from functools import partial from itertools import chain import logging from types import ModuleType @@ -20,8 +21,8 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import ( - EntityServiceResponse, Event, + HassJob, HomeAssistant, ServiceCall, ServiceResponse, @@ -225,13 +226,16 @@ class EntityComponent(Generic[_EntityT]): if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) + service_func: str | HassJob[..., Any] + service_func = func if isinstance(func, str) else HassJob(func) + async def handle_service( call: ServiceCall, ) -> ServiceResponse: """Handle the service.""" result = await service.entity_service_call( - self.hass, self._entities, func, call, required_features + self.hass, self._entities, service_func, call, required_features ) if result: @@ -259,16 +263,21 @@ class EntityComponent(Generic[_EntityT]): if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) - async def handle_service( - call: ServiceCall, - ) -> EntityServiceResponse | None: - """Handle the service.""" - return await service.entity_service_call( - self.hass, self._entities, func, call, required_features - ) + service_func: str | HassJob[..., Any] + service_func = func if isinstance(func, str) else HassJob(func) self.hass.services.async_register( - self.domain, name, handle_service, schema, supports_response + self.domain, + name, + partial( + service.entity_service_call, + self.hass, + self._entities, + service_func, + required_features=required_features, + ), + schema, + supports_response, ) async def async_setup_platform( diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 1bf7d95135b..89eb44a0459 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -5,6 +5,7 @@ import asyncio from collections.abc import Awaitable, Callable, Coroutine, Iterable from contextvars import ContextVar from datetime import datetime, timedelta +from functools import partial from logging import Logger, getLogger from typing import TYPE_CHECKING, Any, Protocol @@ -20,7 +21,7 @@ from homeassistant.core import ( CALLBACK_TYPE, DOMAIN as HOMEASSISTANT_DOMAIN, CoreState, - EntityServiceResponse, + HassJob, HomeAssistant, ServiceCall, SupportsResponse, @@ -833,18 +834,21 @@ class EntityPlatform: if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) - async def handle_service(call: ServiceCall) -> EntityServiceResponse | None: - """Handle the service.""" - return await service.entity_service_call( - self.hass, - self.domain_entities, - func, - call, - required_features, - ) + service_func: str | HassJob[..., Any] + service_func = func if isinstance(func, str) else HassJob(func) self.hass.services.async_register( - self.platform_name, name, handle_service, schema, supports_response + self.platform_name, + name, + partial( + service.entity_service_call, + self.hass, + self.domain_entities, + service_func, + required_features=required_features, + ), + schema, + supports_response, ) async def _update_entity_states(self, now: datetime) -> None: diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 4813a54ac8b..656b2c21129 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable, Coroutine, Iterable +from collections.abc import Awaitable, Callable, Iterable import dataclasses from enum import Enum from functools import cache, partial, wraps @@ -29,6 +29,7 @@ from homeassistant.const import ( from homeassistant.core import ( Context, EntityServiceResponse, + HassJob, HomeAssistant, ServiceCall, ServiceResponse, @@ -191,11 +192,14 @@ class ServiceParams(TypedDict): class ServiceTargetSelector: """Class to hold a target selector for a service.""" + __slots__ = ("entity_ids", "device_ids", "area_ids") + def __init__(self, service_call: ServiceCall) -> None: """Extract ids from service call data.""" - entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID) - device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID) - area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID) + service_call_data = service_call.data + entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID) + device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID) + area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID) self.entity_ids = ( set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() @@ -790,7 +794,7 @@ def _get_permissible_entity_candidates( async def entity_service_call( hass: HomeAssistant, registered_entities: dict[str, Entity], - func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], + func: str | HassJob, call: ServiceCall, required_features: Iterable[int] | None = None, ) -> EntityServiceResponse | None: @@ -926,7 +930,7 @@ async def entity_service_call( async def _handle_entity_call( hass: HomeAssistant, entity: Entity, - func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], + func: str | HassJob, data: dict | ServiceCall, context: Context, ) -> ServiceResponse: @@ -935,11 +939,11 @@ async def _handle_entity_call( task: asyncio.Future[ServiceResponse] | None if isinstance(func, str): - task = hass.async_run_job( - partial(getattr(entity, func), **data) # type: ignore[arg-type] + task = hass.async_run_hass_job( + HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type] ) else: - task = hass.async_run_job(func, entity, data) + task = hass.async_run_hass_job(func, entity, data) # Guard because callback functions do not return a task when passed to # async_run_job. diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 628ead473d7..07e68e081b3 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -19,7 +19,13 @@ from homeassistant.const import ( STATE_ON, EntityCategory, ) -from homeassistant.core import Context, HomeAssistant, ServiceCall, SupportsResponse +from homeassistant.core import ( + Context, + HassJob, + HomeAssistant, + ServiceCall, + SupportsResponse, +) from homeassistant.helpers import ( device_registry as dr, entity_registry as er, @@ -803,7 +809,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) - await service.entity_service_call( hass, mock_entities, - test_service_mock, + HassJob(test_service_mock), ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A], ) @@ -822,7 +828,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) - await service.entity_service_call( hass, mock_entities, - test_service_mock, + HassJob(test_service_mock), ServiceCall( "test_domain", "test_service", {"entity_id": "light.living_room"} ), @@ -839,7 +845,7 @@ async def test_call_with_both_required_features( await service.entity_service_call( hass, mock_entities, - test_service_mock, + HassJob(test_service_mock), ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A | SUPPORT_B], ) @@ -858,7 +864,7 @@ async def test_call_with_one_of_required_features( await service.entity_service_call( hass, mock_entities, - test_service_mock, + HassJob(test_service_mock), ServiceCall("test_domain", "test_service", {"entity_id": "all"}), required_features=[SUPPORT_A, SUPPORT_C], ) @@ -879,7 +885,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None: await service.entity_service_call( hass, mock_entities, - test_service_mock, + HassJob(test_service_mock), ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), ) assert test_service_mock.call_count == 1