Support multiple responses for service calls (#96370)
* add supports_response to platform entity services * support multiple entities in entity_service_call * support legacy response format for service calls * revert changes to script/shell_command * add back test for multiple responses for legacy service * remove SupportsResponse.ONLY_LEGACY * Apply suggestion Co-authored-by: Allen Porter <allen.porter@gmail.com> * test for entity_id remove None * revert Apply suggestion * return EntityServiceResponse from _handle_entity_call * Use asyncio.gather * EntityServiceResponse not Optional * styling --------- Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
parent
b86f3be510
commit
06c9719cd6
8 changed files with 277 additions and 37 deletions
|
@ -300,7 +300,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async_create_event,
|
||||
required_features=[CalendarEntityFeature.CREATE_EVENT],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
component.async_register_legacy_entity_service(
|
||||
SERVICE_LIST_EVENTS,
|
||||
SERVICE_LIST_EVENTS_SCHEMA,
|
||||
async_list_events_service,
|
||||
|
|
|
@ -210,7 +210,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
component = hass.data[DOMAIN] = EntityComponent[WeatherEntity](
|
||||
_LOGGER, DOMAIN, hass, SCAN_INTERVAL
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
component.async_register_legacy_entity_service(
|
||||
SERVICE_GET_FORECAST,
|
||||
{vol.Required("type"): vol.In(("daily", "hourly", "twice_daily"))},
|
||||
async_get_forecast_service,
|
||||
|
|
|
@ -134,6 +134,7 @@ DOMAIN = "homeassistant"
|
|||
BLOCK_LOG_TIMEOUT = 60
|
||||
|
||||
ServiceResponse = JsonObjectType | None
|
||||
EntityServiceResponse = dict[str, ServiceResponse]
|
||||
|
||||
|
||||
class ConfigSource(enum.StrEnum):
|
||||
|
@ -1773,7 +1774,10 @@ class Service:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
|
||||
func: Callable[
|
||||
[ServiceCall],
|
||||
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None,
|
||||
],
|
||||
schema: vol.Schema | None,
|
||||
domain: str,
|
||||
service: str,
|
||||
|
@ -1882,7 +1886,8 @@ class ServiceRegistry:
|
|||
domain: str,
|
||||
service: str,
|
||||
service_func: Callable[
|
||||
[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
|
||||
[ServiceCall],
|
||||
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None,
|
||||
],
|
||||
schema: vol.Schema | None = None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
|
|
|
@ -20,6 +20,7 @@ from homeassistant.const import (
|
|||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
EntityServiceResponse,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
|
@ -217,6 +218,40 @@ class EntityComponent(Generic[_EntityT]):
|
|||
self.hass, self.entities, service_call, expand_group
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_register_legacy_entity_service(
|
||||
self,
|
||||
name: str,
|
||||
schema: dict[str | vol.Marker, Any] | vol.Schema,
|
||||
func: str | Callable[..., Any],
|
||||
required_features: list[int] | None = None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Register an entity service with a legacy response format."""
|
||||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Handle the service."""
|
||||
|
||||
result = await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
)
|
||||
|
||||
if result:
|
||||
if len(result) > 1:
|
||||
raise HomeAssistantError(
|
||||
"Deprecated service call matched more than one entity"
|
||||
)
|
||||
return result.popitem()[1]
|
||||
return None
|
||||
|
||||
self.hass.services.async_register(
|
||||
self.domain, name, handle_service, schema, supports_response
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_register_entity_service(
|
||||
self,
|
||||
|
@ -230,7 +265,9 @@ class EntityComponent(Generic[_EntityT]):
|
|||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call: ServiceCall) -> ServiceResponse:
|
||||
async def handle_service(
|
||||
call: ServiceCall,
|
||||
) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass, self._platforms.values(), func, call, required_features
|
||||
|
|
|
@ -20,8 +20,10 @@ from homeassistant.core import (
|
|||
CALLBACK_TYPE,
|
||||
DOMAIN as HOMEASSISTANT_DOMAIN,
|
||||
CoreState,
|
||||
EntityServiceResponse,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
split_entity_id,
|
||||
valid_entity_id,
|
||||
|
@ -814,6 +816,7 @@ class EntityPlatform:
|
|||
schema: dict[str, Any] | vol.Schema,
|
||||
func: str | Callable[..., Any],
|
||||
required_features: Iterable[int] | None = None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Register an entity service.
|
||||
|
||||
|
@ -825,9 +828,9 @@ class EntityPlatform:
|
|||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call: ServiceCall) -> None:
|
||||
async def handle_service(call: ServiceCall) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
await service.entity_service_call(
|
||||
return await service.entity_service_call(
|
||||
self.hass,
|
||||
[
|
||||
plf
|
||||
|
@ -840,7 +843,7 @@ class EntityPlatform:
|
|||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
self.platform_name, name, handle_service, schema
|
||||
self.platform_name, name, handle_service, schema, supports_response
|
||||
)
|
||||
|
||||
async def _update_entity_states(self, now: datetime) -> None:
|
||||
|
|
|
@ -28,6 +28,7 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import (
|
||||
Context,
|
||||
EntityServiceResponse,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
|
@ -790,7 +791,7 @@ async def entity_service_call(
|
|||
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
||||
call: ServiceCall,
|
||||
required_features: Iterable[int] | None = None,
|
||||
) -> ServiceResponse | None:
|
||||
) -> EntityServiceResponse | None:
|
||||
"""Handle an entity service call.
|
||||
|
||||
Calls all platforms simultaneously.
|
||||
|
@ -870,10 +871,9 @@ async def entity_service_call(
|
|||
return None
|
||||
|
||||
if len(entities) == 1:
|
||||
# Single entity case avoids creating tasks and allows returning
|
||||
# ServiceResponse
|
||||
# Single entity case avoids creating task
|
||||
entity = entities[0]
|
||||
response_data = await _handle_entity_call(
|
||||
single_response = await _handle_entity_call(
|
||||
hass, entity, func, data, call.context
|
||||
)
|
||||
if entity.should_poll:
|
||||
|
@ -881,27 +881,25 @@ async def entity_service_call(
|
|||
# Set context again so it's there when we update
|
||||
entity.async_set_context(call.context)
|
||||
await entity.async_update_ha_state(True)
|
||||
return response_data if return_response else None
|
||||
return {entity.entity_id: single_response} if return_response else None
|
||||
|
||||
if return_response:
|
||||
raise HomeAssistantError(
|
||||
"Service call requested response data but matched more than one entity"
|
||||
)
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
asyncio.create_task(
|
||||
entity.async_request_call(
|
||||
_handle_entity_call(hass, entity, func, data, call.context)
|
||||
)
|
||||
# Use asyncio.gather here to ensure the returned results
|
||||
# are in the same order as the entities list
|
||||
results: list[ServiceResponse] = await asyncio.gather(
|
||||
*[
|
||||
entity.async_request_call(
|
||||
_handle_entity_call(hass, entity, func, data, call.context)
|
||||
)
|
||||
for entity in entities
|
||||
]
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
assert not pending
|
||||
|
||||
for task in done:
|
||||
task.result() # pop exception if have
|
||||
response_data: EntityServiceResponse = {}
|
||||
for entity, result in zip(entities, results):
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
response_data[entity.entity_id] = result
|
||||
|
||||
tasks: list[asyncio.Task[None]] = []
|
||||
|
||||
|
@ -920,7 +918,7 @@ async def entity_service_call(
|
|||
for future in done:
|
||||
future.result() # pop exception if have
|
||||
|
||||
return None
|
||||
return response_data if return_response and response_data else None
|
||||
|
||||
|
||||
async def _handle_entity_call(
|
||||
|
@ -943,7 +941,7 @@ async def _handle_entity_call(
|
|||
|
||||
# Guard because callback functions do not return a task when passed to
|
||||
# async_run_job.
|
||||
result: ServiceResponse | None = None
|
||||
result: ServiceResponse = None
|
||||
if task is not None:
|
||||
result = await task
|
||||
|
||||
|
|
|
@ -531,7 +531,7 @@ async def test_register_entity_service(hass: HomeAssistant) -> None:
|
|||
|
||||
|
||||
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
||||
"""Test an enttiy service that does not support response data."""
|
||||
"""Test an entity service that does support response data."""
|
||||
entity = MockEntity(entity_id=f"{DOMAIN}.entity")
|
||||
|
||||
async def generate_response(
|
||||
|
@ -554,24 +554,25 @@ async def test_register_entity_service_response_data(hass: HomeAssistant) -> Non
|
|||
response_data = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"hello",
|
||||
service_data={"entity_id": entity.entity_id, "some": "data"},
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response_data == {"response-key": "response-value"}
|
||||
assert response_data == {f"{DOMAIN}.entity": {"response-key": "response-value"}}
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data_multiple_matches(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test asking for service response data but matching many entities."""
|
||||
"""Test asking for service response data and matching many entities."""
|
||||
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
|
||||
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
raise ValueError("Should not be invoked")
|
||||
return {"response-key": f"response-value-{target.entity_id}"}
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
await component.async_setup({})
|
||||
|
@ -579,7 +580,80 @@ async def test_register_entity_service_response_data_multiple_matches(
|
|||
|
||||
component.async_register_entity_service(
|
||||
"hello",
|
||||
{},
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
||||
response_data = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response_data == {
|
||||
f"{DOMAIN}.entity1": {"response-key": f"response-value-{DOMAIN}.entity1"},
|
||||
f"{DOMAIN}.entity2": {"response-key": f"response-value-{DOMAIN}.entity2"},
|
||||
}
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data_multiple_matches_raises(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test asking for service response data and matching many entities raises exceptions."""
|
||||
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
|
||||
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
if target.entity_id == f"{DOMAIN}.entity1":
|
||||
raise RuntimeError("Something went wrong")
|
||||
return {"response-key": f"response-value-{target.entity_id}"}
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
await component.async_setup({})
|
||||
await component.async_add_entities([entity1, entity2])
|
||||
|
||||
component.async_register_entity_service(
|
||||
"hello",
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Something went wrong"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_legacy_register_entity_service_response_data_multiple_matches(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test asking for legacy service response data but matching many entities."""
|
||||
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
|
||||
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
return {"response-key": "response-value"}
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
await component.async_setup({})
|
||||
await component.async_add_entities([entity1, entity2])
|
||||
|
||||
component.async_register_legacy_entity_service(
|
||||
"hello",
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
@ -588,6 +662,7 @@ async def test_register_entity_service_response_data_multiple_matches(
|
|||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
|
|
|
@ -9,7 +9,14 @@ from unittest.mock import ANY, Mock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, PERCENTAGE
|
||||
from homeassistant.core import CoreState, HomeAssistant, callback
|
||||
from homeassistant.core import (
|
||||
CoreState,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dr,
|
||||
|
@ -1491,6 +1498,121 @@ async def test_platforms_sharing_services(hass: HomeAssistant) -> None:
|
|||
assert entity2 in entities
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
||||
"""Test an entity service that does supports response data."""
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
assert call.return_response
|
||||
return {"response-key": "response-value"}
|
||||
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity = MockEntity(entity_id="mock_integration.entity")
|
||||
await entity_platform.async_add_entities([entity])
|
||||
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello",
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
||||
response_data = await hass.services.async_call(
|
||||
"mock_platform",
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response_data == {
|
||||
"mock_integration.entity": {"response-key": "response-value"}
|
||||
}
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data_multiple_matches(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test an entity service that does supports response data and matching many entities."""
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
assert call.return_response
|
||||
return {"response-key": f"response-value-{target.entity_id}"}
|
||||
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity1 = MockEntity(entity_id="mock_integration.entity1")
|
||||
entity2 = MockEntity(entity_id="mock_integration.entity2")
|
||||
await entity_platform.async_add_entities([entity1, entity2])
|
||||
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello",
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
||||
response_data = await hass.services.async_call(
|
||||
"mock_platform",
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response_data == {
|
||||
"mock_integration.entity1": {
|
||||
"response-key": "response-value-mock_integration.entity1"
|
||||
},
|
||||
"mock_integration.entity2": {
|
||||
"response-key": "response-value-mock_integration.entity2"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def test_register_entity_service_response_data_multiple_matches_raises(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test entity service response matching many entities raises."""
|
||||
|
||||
async def generate_response(
|
||||
target: MockEntity, call: ServiceCall
|
||||
) -> ServiceResponse:
|
||||
assert call.return_response
|
||||
if target.entity_id == "mock_integration.entity1":
|
||||
raise RuntimeError("Something went wrong")
|
||||
return {"response-key": f"response-value-{target.entity_id}"}
|
||||
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity1 = MockEntity(entity_id="mock_integration.entity1")
|
||||
entity2 = MockEntity(entity_id="mock_integration.entity2")
|
||||
await entity_platform.async_add_entities([entity1, entity2])
|
||||
|
||||
entity_platform.async_register_entity_service(
|
||||
"hello",
|
||||
{"some": str},
|
||||
generate_response,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="Something went wrong"):
|
||||
await hass.services.async_call(
|
||||
"mock_platform",
|
||||
"hello",
|
||||
service_data={"some": "data"},
|
||||
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
|
||||
"""Test specifying an invalid entity id."""
|
||||
platform = MockEntityPlatform(hass)
|
||||
|
|
Loading…
Add table
Reference in a new issue