Add support for services to return data (#94401)

* Add support for service calls with resopnse data.

Update the service calls to allow returning responses with data,
with an initial use case supporting basic service calls usable
within script.

* Revert enttiy platform/component changes

* Remove unnecessary comma diff

* Revert additional unnecessary changes

* Simplify service call

* Simplify and fix typing and revert whitespace

* Clarify typing intent

* Revert more entity service calls

* Revert additional entity service changes

* Set blocking=True for group notify service call

* Revert unnecessary changes

* Reverting more whitespace changes

* Revert more service changes

* Add test coverage for None return case

* Add parameter to service calls indicating return values were requested

* Update tests/test_core.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Add additional service call tests

* Update test comment

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2023-06-16 09:43:35 -07:00 committed by GitHub
parent 12129e9d21
commit 84c66b3cad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 184 additions and 22 deletions

View file

@ -66,7 +66,7 @@ class GroupNotifyPlatform(BaseNotificationService):
payload: dict[str, Any] = {ATTR_MESSAGE: message}
payload.update({key: val for key, val in kwargs.items() if val})
tasks: list[asyncio.Task[bool | None]] = []
tasks: list[asyncio.Task[Any]] = []
for entity in self.entities:
sending_payload = deepcopy(payload.copy())
if (default_data := entity.get(ATTR_DATA)) is not None:
@ -74,7 +74,7 @@ class GroupNotifyPlatform(BaseNotificationService):
tasks.append(
asyncio.create_task(
self.hass.services.async_call(
DOMAIN, entity[ATTR_SERVICE], sending_payload
DOMAIN, entity[ATTR_SERVICE], sending_payload, blocking=True
)
)
)

View file

@ -88,6 +88,7 @@ from .util.async_ import (
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager
from .util.unit_system import (
@ -130,6 +131,8 @@ DOMAIN = "homeassistant"
# How long to wait to log tasks that are blocking
BLOCK_LOG_TIMEOUT = 60
ServiceResult = JsonObjectType | None
class ConfigSource(StrEnum):
"""Source of core configuration."""
@ -1659,7 +1662,7 @@ class Service:
def __init__(
self,
func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
schema: vol.Schema | None,
domain: str,
service: str,
@ -1673,7 +1676,7 @@ class Service:
class ServiceCall:
"""Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context"]
__slots__ = ["domain", "service", "data", "context", "return_values"]
def __init__(
self,
@ -1681,12 +1684,14 @@ class ServiceCall:
service: str,
data: dict[str, Any] | None = None,
context: Context | None = None,
return_values: bool = False,
) -> None:
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = ReadOnlyDict(data or {})
self.context = context or Context()
self.return_values = return_values
def __repr__(self) -> str:
"""Return the representation of the service."""
@ -1731,7 +1736,10 @@ class ServiceRegistry:
self,
domain: str,
service: str,
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResult] | None,
],
schema: vol.Schema | None = None,
) -> None:
"""Register a service.
@ -1747,7 +1755,9 @@ class ServiceRegistry:
self,
domain: str,
service: str,
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
],
schema: vol.Schema | None = None,
) -> None:
"""Register a service.
@ -1805,13 +1815,22 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
) -> bool | None:
return_values: bool = False,
) -> ServiceResult:
"""Call a service.
See description of async_call for details.
"""
return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context, target),
self.async_call(
domain,
service,
service_data,
blocking,
context,
target,
return_values,
),
self._hass.loop,
).result()
@ -1823,11 +1842,16 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
) -> None:
return_values: bool = False,
) -> ServiceResult:
"""Call a service.
Specify blocking=True to wait until service is executed.
If return_values=True, indicates that the caller can consume return values
from the service, if any. Return values are a dict that can be returned by the
standard JSON serialization process. Return values can only be used with blocking=True.
This method will fire an event to indicate the service has been called.
Because the service is sent as an event you are not allowed to use
@ -1840,6 +1864,9 @@ class ServiceRegistry:
context = context or Context()
service_data = service_data or {}
if return_values and not blocking:
raise ValueError("Invalid argument return_values=True when blocking=False")
try:
handler = self._services[domain][service]
except KeyError:
@ -1862,7 +1889,9 @@ class ServiceRegistry:
else:
processed_data = service_data
service_call = ServiceCall(domain, service, processed_data, context)
service_call = ServiceCall(
domain, service, processed_data, context, return_values
)
self._hass.bus.async_fire(
EVENT_CALL_SERVICE,
@ -1877,13 +1906,20 @@ class ServiceRegistry:
coro = self._execute_service(handler, service_call)
if not blocking:
self._run_service_in_background(coro, service_call)
return
return None
await coro
response_data = await coro
if not return_values:
return None
if not isinstance(response_data, dict):
raise HomeAssistantError(
f"Service response data expected a dictionary, was {type(response_data)}"
)
return response_data
def _run_service_in_background(
self,
coro_or_task: Coroutine[Any, Any, None] | asyncio.Task[None],
coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any],
service_call: ServiceCall,
) -> None:
"""Run service call in background, catching and logging any exceptions."""
@ -1909,18 +1945,21 @@ class ServiceRegistry:
async def _execute_service(
self, handler: Service, service_call: ServiceCall
) -> None:
) -> ServiceResult:
"""Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction:
await cast(Callable[[ServiceCall], Awaitable[None]], handler.job.target)(
return await cast(
Callable[[ServiceCall], Awaitable[ServiceResult]],
handler.job.target,
)(service_call)
if handler.job.job_type == HassJobType.Callback:
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
service_call
)
elif handler.job.job_type == HassJobType.Callback:
cast(Callable[[ServiceCall], None], handler.job.target)(service_call)
else:
await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], None], handler.job.target), service_call
)
return await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
service_call,
)
class Config:

View file

@ -33,8 +33,9 @@ from homeassistant.const import (
__version__,
)
import homeassistant.core as ha
from homeassistant.core import HassJob, HomeAssistant, State
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
from homeassistant.exceptions import (
HomeAssistantError,
InvalidEntityFormatError,
InvalidStateError,
MaxLengthExceeded,
@ -1082,6 +1083,128 @@ async def test_serviceregistry_callback_service_raise_exception(
await hass.async_block_till_done()
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
"""Test service call for a service that has return values."""
def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
"""Test service call for an async service that has return values."""
async def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_services_call_return_values_requires_blocking(
hass: HomeAssistant,
) -> None:
"""Test that non-blocking service calls cannot return values."""
async_mock_service(hass, "test_domain", "test_service")
with pytest.raises(ValueError, match="when blocking=False"):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=False,
return_values=True,
)
@pytest.mark.parametrize(
("return_value", "expected_error"),
[
(True, "expected a dictionary"),
(False, "expected a dictionary"),
(None, "expected a dictionary"),
("some-value", "expected a dictionary"),
(["some-list"], "expected a dictionary"),
],
)
async def test_serviceregistry_return_values_invalid(
hass: HomeAssistant, return_value: Any, expected_error: str
) -> None:
"""Test service call return values are not returned when there is no result schema."""
def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return return_value
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
with pytest.raises(HomeAssistantError, match=expected_error):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
"""Test service call data when not asked for return values."""
def service_handler(call: ServiceCall) -> None:
"""Service handler coroutine."""
assert not call.return_values
return
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
)
await hass.async_block_till_done()
assert not result
async def test_config_defaults() -> None:
"""Test config defaults."""
hass = Mock()