Support multiple responses for service calls (#96370)

* add supports_response to platform entity services

* support multiple entities in entity_service_call

* support legacy response format for service calls

* revert changes to script/shell_command

* add back test for multiple responses for legacy service

* remove SupportsResponse.ONLY_LEGACY

* Apply suggestion

Co-authored-by: Allen Porter <allen.porter@gmail.com>

* test for entity_id remove None

* revert Apply suggestion

* return EntityServiceResponse from _handle_entity_call

* Use asyncio.gather

* EntityServiceResponse not Optional

* styling

---------

Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
Kevin Stillhammer 2023-11-03 02:37:35 +01:00 committed by GitHub
parent b86f3be510
commit 06c9719cd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 277 additions and 37 deletions

View file

@ -300,7 +300,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_create_event,
required_features=[CalendarEntityFeature.CREATE_EVENT],
)
component.async_register_entity_service(
component.async_register_legacy_entity_service(
SERVICE_LIST_EVENTS,
SERVICE_LIST_EVENTS_SCHEMA,
async_list_events_service,

View file

@ -210,7 +210,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
component = hass.data[DOMAIN] = EntityComponent[WeatherEntity](
_LOGGER, DOMAIN, hass, SCAN_INTERVAL
)
component.async_register_entity_service(
component.async_register_legacy_entity_service(
SERVICE_GET_FORECAST,
{vol.Required("type"): vol.In(("daily", "hourly", "twice_daily"))},
async_get_forecast_service,

View file

@ -134,6 +134,7 @@ DOMAIN = "homeassistant"
BLOCK_LOG_TIMEOUT = 60
ServiceResponse = JsonObjectType | None
EntityServiceResponse = dict[str, ServiceResponse]
class ConfigSource(enum.StrEnum):
@ -1773,7 +1774,10 @@ class Service:
def __init__(
self,
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None,
],
schema: vol.Schema | None,
domain: str,
service: str,
@ -1882,7 +1886,8 @@ class ServiceRegistry:
domain: str,
service: str,
service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] | None,
],
schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,

View file

@ -20,6 +20,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import (
EntityServiceResponse,
Event,
HomeAssistant,
ServiceCall,
@ -217,6 +218,40 @@ class EntityComponent(Generic[_EntityT]):
self.hass, self.entities, service_call, expand_group
)
@callback
def async_register_legacy_entity_service(
self,
name: str,
schema: dict[str | vol.Marker, Any] | vol.Schema,
func: str | Callable[..., Any],
required_features: list[int] | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Register an entity service with a legacy response format."""
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(
call: ServiceCall,
) -> ServiceResponse:
"""Handle the service."""
result = await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features
)
if result:
if len(result) > 1:
raise HomeAssistantError(
"Deprecated service call matched more than one entity"
)
return result.popitem()[1]
return None
self.hass.services.async_register(
self.domain, name, handle_service, schema, supports_response
)
@callback
def async_register_entity_service(
self,
@ -230,7 +265,9 @@ class EntityComponent(Generic[_EntityT]):
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call: ServiceCall) -> ServiceResponse:
async def handle_service(
call: ServiceCall,
) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass, self._platforms.values(), func, call, required_features

View file

@ -20,8 +20,10 @@ from homeassistant.core import (
CALLBACK_TYPE,
DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState,
EntityServiceResponse,
HomeAssistant,
ServiceCall,
SupportsResponse,
callback,
split_entity_id,
valid_entity_id,
@ -814,6 +816,7 @@ class EntityPlatform:
schema: dict[str, Any] | vol.Schema,
func: str | Callable[..., Any],
required_features: Iterable[int] | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Register an entity service.
@ -825,9 +828,9 @@ class EntityPlatform:
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call: ServiceCall) -> None:
async def handle_service(call: ServiceCall) -> EntityServiceResponse | None:
"""Handle the service."""
await service.entity_service_call(
return await service.entity_service_call(
self.hass,
[
plf
@ -840,7 +843,7 @@ class EntityPlatform:
)
self.hass.services.async_register(
self.platform_name, name, handle_service, schema
self.platform_name, name, handle_service, schema, supports_response
)
async def _update_entity_states(self, now: datetime) -> None:

View file

@ -28,6 +28,7 @@ from homeassistant.const import (
)
from homeassistant.core import (
Context,
EntityServiceResponse,
HomeAssistant,
ServiceCall,
ServiceResponse,
@ -790,7 +791,7 @@ async def entity_service_call(
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
call: ServiceCall,
required_features: Iterable[int] | None = None,
) -> ServiceResponse | None:
) -> EntityServiceResponse | None:
"""Handle an entity service call.
Calls all platforms simultaneously.
@ -870,10 +871,9 @@ async def entity_service_call(
return None
if len(entities) == 1:
# Single entity case avoids creating tasks and allows returning
# ServiceResponse
# Single entity case avoids creating task
entity = entities[0]
response_data = await _handle_entity_call(
single_response = await _handle_entity_call(
hass, entity, func, data, call.context
)
if entity.should_poll:
@ -881,27 +881,25 @@ async def entity_service_call(
# Set context again so it's there when we update
entity.async_set_context(call.context)
await entity.async_update_ha_state(True)
return response_data if return_response else None
return {entity.entity_id: single_response} if return_response else None
if return_response:
raise HomeAssistantError(
"Service call requested response data but matched more than one entity"
)
done, pending = await asyncio.wait(
[
asyncio.create_task(
entity.async_request_call(
_handle_entity_call(hass, entity, func, data, call.context)
)
# Use asyncio.gather here to ensure the returned results
# are in the same order as the entities list
results: list[ServiceResponse] = await asyncio.gather(
*[
entity.async_request_call(
_handle_entity_call(hass, entity, func, data, call.context)
)
for entity in entities
]
],
return_exceptions=True,
)
assert not pending
for task in done:
task.result() # pop exception if have
response_data: EntityServiceResponse = {}
for entity, result in zip(entities, results):
if isinstance(result, Exception):
raise result
response_data[entity.entity_id] = result
tasks: list[asyncio.Task[None]] = []
@ -920,7 +918,7 @@ async def entity_service_call(
for future in done:
future.result() # pop exception if have
return None
return response_data if return_response and response_data else None
async def _handle_entity_call(
@ -943,7 +941,7 @@ async def _handle_entity_call(
# Guard because callback functions do not return a task when passed to
# async_run_job.
result: ServiceResponse | None = None
result: ServiceResponse = None
if task is not None:
result = await task

View file

@ -531,7 +531,7 @@ async def test_register_entity_service(hass: HomeAssistant) -> None:
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
"""Test an enttiy service that does not support response data."""
"""Test an entity service that does support response data."""
entity = MockEntity(entity_id=f"{DOMAIN}.entity")
async def generate_response(
@ -554,24 +554,25 @@ async def test_register_entity_service_response_data(hass: HomeAssistant) -> Non
response_data = await hass.services.async_call(
DOMAIN,
"hello",
service_data={"entity_id": entity.entity_id, "some": "data"},
service_data={"some": "data"},
target={"entity_id": [entity.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {"response-key": "response-value"}
assert response_data == {f"{DOMAIN}.entity": {"response-key": "response-value"}}
async def test_register_entity_service_response_data_multiple_matches(
hass: HomeAssistant,
) -> None:
"""Test asking for service response data but matching many entities."""
"""Test asking for service response data and matching many entities."""
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
raise ValueError("Should not be invoked")
return {"response-key": f"response-value-{target.entity_id}"}
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({})
@ -579,7 +580,80 @@ async def test_register_entity_service_response_data_multiple_matches(
component.async_register_entity_service(
"hello",
{},
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
response_data = await hass.services.async_call(
DOMAIN,
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {
f"{DOMAIN}.entity1": {"response-key": f"response-value-{DOMAIN}.entity1"},
f"{DOMAIN}.entity2": {"response-key": f"response-value-{DOMAIN}.entity2"},
}
async def test_register_entity_service_response_data_multiple_matches_raises(
hass: HomeAssistant,
) -> None:
"""Test asking for service response data and matching many entities raises exceptions."""
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
if target.entity_id == f"{DOMAIN}.entity1":
raise RuntimeError("Something went wrong")
return {"response-key": f"response-value-{target.entity_id}"}
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({})
await component.async_add_entities([entity1, entity2])
component.async_register_entity_service(
"hello",
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
with pytest.raises(RuntimeError, match="Something went wrong"):
await hass.services.async_call(
DOMAIN,
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
async def test_legacy_register_entity_service_response_data_multiple_matches(
hass: HomeAssistant,
) -> None:
"""Test asking for legacy service response data but matching many entities."""
entity1 = MockEntity(entity_id=f"{DOMAIN}.entity1")
entity2 = MockEntity(entity_id=f"{DOMAIN}.entity2")
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
return {"response-key": "response-value"}
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({})
await component.async_add_entities([entity1, entity2])
component.async_register_legacy_entity_service(
"hello",
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
@ -588,6 +662,7 @@ async def test_register_entity_service_response_data_multiple_matches(
await hass.services.async_call(
DOMAIN,
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,

View file

@ -9,7 +9,14 @@ from unittest.mock import ANY, Mock, patch
import pytest
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, PERCENTAGE
from homeassistant.core import CoreState, HomeAssistant, callback
from homeassistant.core import (
CoreState,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import (
device_registry as dr,
@ -1491,6 +1498,121 @@ async def test_platforms_sharing_services(hass: HomeAssistant) -> None:
assert entity2 in entities
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
"""Test an entity service that does supports response data."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
return {"response-key": "response-value"}
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity = MockEntity(entity_id="mock_integration.entity")
await entity_platform.async_add_entities([entity])
entity_platform.async_register_entity_service(
"hello",
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
response_data = await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {
"mock_integration.entity": {"response-key": "response-value"}
}
async def test_register_entity_service_response_data_multiple_matches(
hass: HomeAssistant,
) -> None:
"""Test an entity service that does supports response data and matching many entities."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
return {"response-key": f"response-value-{target.entity_id}"}
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity1")
entity2 = MockEntity(entity_id="mock_integration.entity2")
await entity_platform.async_add_entities([entity1, entity2])
entity_platform.async_register_entity_service(
"hello",
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
response_data = await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {
"mock_integration.entity1": {
"response-key": "response-value-mock_integration.entity1"
},
"mock_integration.entity2": {
"response-key": "response-value-mock_integration.entity2"
},
}
async def test_register_entity_service_response_data_multiple_matches_raises(
hass: HomeAssistant,
) -> None:
"""Test entity service response matching many entities raises."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
if target.entity_id == "mock_integration.entity1":
raise RuntimeError("Something went wrong")
return {"response-key": f"response-value-{target.entity_id}"}
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity1")
entity2 = MockEntity(entity_id="mock_integration.entity2")
await entity_platform.async_add_entities([entity1, entity2])
entity_platform.async_register_entity_service(
"hello",
{"some": str},
generate_response,
supports_response=SupportsResponse.ONLY,
)
with pytest.raises(RuntimeError, match="Something went wrong"):
await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
"""Test specifying an invalid entity id."""
platform = MockEntityPlatform(hass)