Add support for services to return data (#94401)
* Add support for service calls with resopnse data. Update the service calls to allow returning responses with data, with an initial use case supporting basic service calls usable within script. * Revert enttiy platform/component changes * Remove unnecessary comma diff * Revert additional unnecessary changes * Simplify service call * Simplify and fix typing and revert whitespace * Clarify typing intent * Revert more entity service calls * Revert additional entity service changes * Set blocking=True for group notify service call * Revert unnecessary changes * Reverting more whitespace changes * Revert more service changes * Add test coverage for None return case * Add parameter to service calls indicating return values were requested * Update tests/test_core.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Add additional service call tests * Update test comment --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
12129e9d21
commit
84c66b3cad
3 changed files with 184 additions and 22 deletions
|
@ -66,7 +66,7 @@ class GroupNotifyPlatform(BaseNotificationService):
|
|||
payload: dict[str, Any] = {ATTR_MESSAGE: message}
|
||||
payload.update({key: val for key, val in kwargs.items() if val})
|
||||
|
||||
tasks: list[asyncio.Task[bool | None]] = []
|
||||
tasks: list[asyncio.Task[Any]] = []
|
||||
for entity in self.entities:
|
||||
sending_payload = deepcopy(payload.copy())
|
||||
if (default_data := entity.get(ATTR_DATA)) is not None:
|
||||
|
@ -74,7 +74,7 @@ class GroupNotifyPlatform(BaseNotificationService):
|
|||
tasks.append(
|
||||
asyncio.create_task(
|
||||
self.hass.services.async_call(
|
||||
DOMAIN, entity[ATTR_SERVICE], sending_payload
|
||||
DOMAIN, entity[ATTR_SERVICE], sending_payload, blocking=True
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -88,6 +88,7 @@ from .util.async_ import (
|
|||
run_callback_threadsafe,
|
||||
shutdown_run_callback_threadsafe,
|
||||
)
|
||||
from .util.json import JsonObjectType
|
||||
from .util.read_only_dict import ReadOnlyDict
|
||||
from .util.timeout import TimeoutManager
|
||||
from .util.unit_system import (
|
||||
|
@ -130,6 +131,8 @@ DOMAIN = "homeassistant"
|
|||
# How long to wait to log tasks that are blocking
|
||||
BLOCK_LOG_TIMEOUT = 60
|
||||
|
||||
ServiceResult = JsonObjectType | None
|
||||
|
||||
|
||||
class ConfigSource(StrEnum):
|
||||
"""Source of core configuration."""
|
||||
|
@ -1659,7 +1662,7 @@ class Service:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
||||
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
|
||||
schema: vol.Schema | None,
|
||||
domain: str,
|
||||
service: str,
|
||||
|
@ -1673,7 +1676,7 @@ class Service:
|
|||
class ServiceCall:
|
||||
"""Representation of a call to a service."""
|
||||
|
||||
__slots__ = ["domain", "service", "data", "context"]
|
||||
__slots__ = ["domain", "service", "data", "context", "return_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1681,12 +1684,14 @@ class ServiceCall:
|
|||
service: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
context: Context | None = None,
|
||||
return_values: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = ReadOnlyDict(data or {})
|
||||
self.context = context or Context()
|
||||
self.return_values = return_values
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the service."""
|
||||
|
@ -1731,7 +1736,10 @@ class ServiceRegistry:
|
|||
self,
|
||||
domain: str,
|
||||
service: str,
|
||||
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
||||
service_func: Callable[
|
||||
[ServiceCall],
|
||||
Coroutine[Any, Any, ServiceResult] | None,
|
||||
],
|
||||
schema: vol.Schema | None = None,
|
||||
) -> None:
|
||||
"""Register a service.
|
||||
|
@ -1747,7 +1755,9 @@ class ServiceRegistry:
|
|||
self,
|
||||
domain: str,
|
||||
service: str,
|
||||
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
||||
service_func: Callable[
|
||||
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
|
||||
],
|
||||
schema: vol.Schema | None = None,
|
||||
) -> None:
|
||||
"""Register a service.
|
||||
|
@ -1805,13 +1815,22 @@ class ServiceRegistry:
|
|||
blocking: bool = False,
|
||||
context: Context | None = None,
|
||||
target: dict[str, Any] | None = None,
|
||||
) -> bool | None:
|
||||
return_values: bool = False,
|
||||
) -> ServiceResult:
|
||||
"""Call a service.
|
||||
|
||||
See description of async_call for details.
|
||||
"""
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.async_call(domain, service, service_data, blocking, context, target),
|
||||
self.async_call(
|
||||
domain,
|
||||
service,
|
||||
service_data,
|
||||
blocking,
|
||||
context,
|
||||
target,
|
||||
return_values,
|
||||
),
|
||||
self._hass.loop,
|
||||
).result()
|
||||
|
||||
|
@ -1823,11 +1842,16 @@ class ServiceRegistry:
|
|||
blocking: bool = False,
|
||||
context: Context | None = None,
|
||||
target: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
return_values: bool = False,
|
||||
) -> ServiceResult:
|
||||
"""Call a service.
|
||||
|
||||
Specify blocking=True to wait until service is executed.
|
||||
|
||||
If return_values=True, indicates that the caller can consume return values
|
||||
from the service, if any. Return values are a dict that can be returned by the
|
||||
standard JSON serialization process. Return values can only be used with blocking=True.
|
||||
|
||||
This method will fire an event to indicate the service has been called.
|
||||
|
||||
Because the service is sent as an event you are not allowed to use
|
||||
|
@ -1840,6 +1864,9 @@ class ServiceRegistry:
|
|||
context = context or Context()
|
||||
service_data = service_data or {}
|
||||
|
||||
if return_values and not blocking:
|
||||
raise ValueError("Invalid argument return_values=True when blocking=False")
|
||||
|
||||
try:
|
||||
handler = self._services[domain][service]
|
||||
except KeyError:
|
||||
|
@ -1862,7 +1889,9 @@ class ServiceRegistry:
|
|||
else:
|
||||
processed_data = service_data
|
||||
|
||||
service_call = ServiceCall(domain, service, processed_data, context)
|
||||
service_call = ServiceCall(
|
||||
domain, service, processed_data, context, return_values
|
||||
)
|
||||
|
||||
self._hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
|
@ -1877,13 +1906,20 @@ class ServiceRegistry:
|
|||
coro = self._execute_service(handler, service_call)
|
||||
if not blocking:
|
||||
self._run_service_in_background(coro, service_call)
|
||||
return
|
||||
return None
|
||||
|
||||
await coro
|
||||
response_data = await coro
|
||||
if not return_values:
|
||||
return None
|
||||
if not isinstance(response_data, dict):
|
||||
raise HomeAssistantError(
|
||||
f"Service response data expected a dictionary, was {type(response_data)}"
|
||||
)
|
||||
return response_data
|
||||
|
||||
def _run_service_in_background(
|
||||
self,
|
||||
coro_or_task: Coroutine[Any, Any, None] | asyncio.Task[None],
|
||||
coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any],
|
||||
service_call: ServiceCall,
|
||||
) -> None:
|
||||
"""Run service call in background, catching and logging any exceptions."""
|
||||
|
@ -1909,18 +1945,21 @@ class ServiceRegistry:
|
|||
|
||||
async def _execute_service(
|
||||
self, handler: Service, service_call: ServiceCall
|
||||
) -> None:
|
||||
) -> ServiceResult:
|
||||
"""Execute a service."""
|
||||
if handler.job.job_type == HassJobType.Coroutinefunction:
|
||||
await cast(Callable[[ServiceCall], Awaitable[None]], handler.job.target)(
|
||||
return await cast(
|
||||
Callable[[ServiceCall], Awaitable[ServiceResult]],
|
||||
handler.job.target,
|
||||
)(service_call)
|
||||
if handler.job.job_type == HassJobType.Callback:
|
||||
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
|
||||
service_call
|
||||
)
|
||||
elif handler.job.job_type == HassJobType.Callback:
|
||||
cast(Callable[[ServiceCall], None], handler.job.target)(service_call)
|
||||
else:
|
||||
await self._hass.async_add_executor_job(
|
||||
cast(Callable[[ServiceCall], None], handler.job.target), service_call
|
||||
)
|
||||
return await self._hass.async_add_executor_job(
|
||||
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
|
||||
service_call,
|
||||
)
|
||||
|
||||
|
||||
class Config:
|
||||
|
|
|
@ -33,8 +33,9 @@ from homeassistant.const import (
|
|||
__version__,
|
||||
)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.core import HassJob, HomeAssistant, State
|
||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
InvalidEntityFormatError,
|
||||
InvalidStateError,
|
||||
MaxLengthExceeded,
|
||||
|
@ -1082,6 +1083,128 @@ async def test_serviceregistry_callback_service_raise_exception(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call for a service that has return values."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
return {"test-reply": "test-value1"}
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result == {"test-reply": "test-value1"}
|
||||
|
||||
|
||||
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call for an async service that has return values."""
|
||||
|
||||
async def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
return {"test-reply": "test-value1"}
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result == {"test-reply": "test-value1"}
|
||||
|
||||
|
||||
async def test_services_call_return_values_requires_blocking(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that non-blocking service calls cannot return values."""
|
||||
async_mock_service(hass, "test_domain", "test_service")
|
||||
with pytest.raises(ValueError, match="when blocking=False"):
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=False,
|
||||
return_values=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("return_value", "expected_error"),
|
||||
[
|
||||
(True, "expected a dictionary"),
|
||||
(False, "expected a dictionary"),
|
||||
(None, "expected a dictionary"),
|
||||
("some-value", "expected a dictionary"),
|
||||
(["some-list"], "expected a dictionary"),
|
||||
],
|
||||
)
|
||||
async def test_serviceregistry_return_values_invalid(
|
||||
hass: HomeAssistant, return_value: Any, expected_error: str
|
||||
) -> None:
|
||||
"""Test service call return values are not returned when there is no result schema."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||
"""Service handler coroutine."""
|
||||
assert call.return_values
|
||||
return return_value
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
)
|
||||
with pytest.raises(HomeAssistantError, match=expected_error):
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
return_values=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
|
||||
"""Test service call data when not asked for return values."""
|
||||
|
||||
def service_handler(call: ServiceCall) -> None:
|
||||
"""Service handler coroutine."""
|
||||
assert not call.return_values
|
||||
return
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_handler,
|
||||
)
|
||||
result = await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
service_data={},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert not result
|
||||
|
||||
|
||||
async def test_config_defaults() -> None:
|
||||
"""Test config defaults."""
|
||||
hass = Mock()
|
||||
|
|
Loading…
Add table
Reference in a new issue