Add support for services to return data (#94401)
* Add support for service calls with resopnse data. Update the service calls to allow returning responses with data, with an initial use case supporting basic service calls usable within script. * Revert enttiy platform/component changes * Remove unnecessary comma diff * Revert additional unnecessary changes * Simplify service call * Simplify and fix typing and revert whitespace * Clarify typing intent * Revert more entity service calls * Revert additional entity service changes * Set blocking=True for group notify service call * Revert unnecessary changes * Reverting more whitespace changes * Revert more service changes * Add test coverage for None return case * Add parameter to service calls indicating return values were requested * Update tests/test_core.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Add additional service call tests * Update test comment --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
12129e9d21
commit
84c66b3cad
3 changed files with 184 additions and 22 deletions
|
@ -66,7 +66,7 @@ class GroupNotifyPlatform(BaseNotificationService):
|
||||||
payload: dict[str, Any] = {ATTR_MESSAGE: message}
|
payload: dict[str, Any] = {ATTR_MESSAGE: message}
|
||||||
payload.update({key: val for key, val in kwargs.items() if val})
|
payload.update({key: val for key, val in kwargs.items() if val})
|
||||||
|
|
||||||
tasks: list[asyncio.Task[bool | None]] = []
|
tasks: list[asyncio.Task[Any]] = []
|
||||||
for entity in self.entities:
|
for entity in self.entities:
|
||||||
sending_payload = deepcopy(payload.copy())
|
sending_payload = deepcopy(payload.copy())
|
||||||
if (default_data := entity.get(ATTR_DATA)) is not None:
|
if (default_data := entity.get(ATTR_DATA)) is not None:
|
||||||
|
@ -74,7 +74,7 @@ class GroupNotifyPlatform(BaseNotificationService):
|
||||||
tasks.append(
|
tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.hass.services.async_call(
|
self.hass.services.async_call(
|
||||||
DOMAIN, entity[ATTR_SERVICE], sending_payload
|
DOMAIN, entity[ATTR_SERVICE], sending_payload, blocking=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -88,6 +88,7 @@ from .util.async_ import (
|
||||||
run_callback_threadsafe,
|
run_callback_threadsafe,
|
||||||
shutdown_run_callback_threadsafe,
|
shutdown_run_callback_threadsafe,
|
||||||
)
|
)
|
||||||
|
from .util.json import JsonObjectType
|
||||||
from .util.read_only_dict import ReadOnlyDict
|
from .util.read_only_dict import ReadOnlyDict
|
||||||
from .util.timeout import TimeoutManager
|
from .util.timeout import TimeoutManager
|
||||||
from .util.unit_system import (
|
from .util.unit_system import (
|
||||||
|
@ -130,6 +131,8 @@ DOMAIN = "homeassistant"
|
||||||
# How long to wait to log tasks that are blocking
|
# How long to wait to log tasks that are blocking
|
||||||
BLOCK_LOG_TIMEOUT = 60
|
BLOCK_LOG_TIMEOUT = 60
|
||||||
|
|
||||||
|
ServiceResult = JsonObjectType | None
|
||||||
|
|
||||||
|
|
||||||
class ConfigSource(StrEnum):
|
class ConfigSource(StrEnum):
|
||||||
"""Source of core configuration."""
|
"""Source of core configuration."""
|
||||||
|
@ -1659,7 +1662,7 @@ class Service:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
|
||||||
schema: vol.Schema | None,
|
schema: vol.Schema | None,
|
||||||
domain: str,
|
domain: str,
|
||||||
service: str,
|
service: str,
|
||||||
|
@ -1673,7 +1676,7 @@ class Service:
|
||||||
class ServiceCall:
|
class ServiceCall:
|
||||||
"""Representation of a call to a service."""
|
"""Representation of a call to a service."""
|
||||||
|
|
||||||
__slots__ = ["domain", "service", "data", "context"]
|
__slots__ = ["domain", "service", "data", "context", "return_values"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -1681,12 +1684,14 @@ class ServiceCall:
|
||||||
service: str,
|
service: str,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
context: Context | None = None,
|
context: Context | None = None,
|
||||||
|
return_values: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a service call."""
|
"""Initialize a service call."""
|
||||||
self.domain = domain.lower()
|
self.domain = domain.lower()
|
||||||
self.service = service.lower()
|
self.service = service.lower()
|
||||||
self.data = ReadOnlyDict(data or {})
|
self.data = ReadOnlyDict(data or {})
|
||||||
self.context = context or Context()
|
self.context = context or Context()
|
||||||
|
self.return_values = return_values
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the service."""
|
"""Return the representation of the service."""
|
||||||
|
@ -1731,7 +1736,10 @@ class ServiceRegistry:
|
||||||
self,
|
self,
|
||||||
domain: str,
|
domain: str,
|
||||||
service: str,
|
service: str,
|
||||||
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
service_func: Callable[
|
||||||
|
[ServiceCall],
|
||||||
|
Coroutine[Any, Any, ServiceResult] | None,
|
||||||
|
],
|
||||||
schema: vol.Schema | None = None,
|
schema: vol.Schema | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a service.
|
"""Register a service.
|
||||||
|
@ -1747,7 +1755,9 @@ class ServiceRegistry:
|
||||||
self,
|
self,
|
||||||
domain: str,
|
domain: str,
|
||||||
service: str,
|
service: str,
|
||||||
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None],
|
service_func: Callable[
|
||||||
|
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
|
||||||
|
],
|
||||||
schema: vol.Schema | None = None,
|
schema: vol.Schema | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a service.
|
"""Register a service.
|
||||||
|
@ -1805,13 +1815,22 @@ class ServiceRegistry:
|
||||||
blocking: bool = False,
|
blocking: bool = False,
|
||||||
context: Context | None = None,
|
context: Context | None = None,
|
||||||
target: dict[str, Any] | None = None,
|
target: dict[str, Any] | None = None,
|
||||||
) -> bool | None:
|
return_values: bool = False,
|
||||||
|
) -> ServiceResult:
|
||||||
"""Call a service.
|
"""Call a service.
|
||||||
|
|
||||||
See description of async_call for details.
|
See description of async_call for details.
|
||||||
"""
|
"""
|
||||||
return asyncio.run_coroutine_threadsafe(
|
return asyncio.run_coroutine_threadsafe(
|
||||||
self.async_call(domain, service, service_data, blocking, context, target),
|
self.async_call(
|
||||||
|
domain,
|
||||||
|
service,
|
||||||
|
service_data,
|
||||||
|
blocking,
|
||||||
|
context,
|
||||||
|
target,
|
||||||
|
return_values,
|
||||||
|
),
|
||||||
self._hass.loop,
|
self._hass.loop,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
@ -1823,11 +1842,16 @@ class ServiceRegistry:
|
||||||
blocking: bool = False,
|
blocking: bool = False,
|
||||||
context: Context | None = None,
|
context: Context | None = None,
|
||||||
target: dict[str, Any] | None = None,
|
target: dict[str, Any] | None = None,
|
||||||
) -> None:
|
return_values: bool = False,
|
||||||
|
) -> ServiceResult:
|
||||||
"""Call a service.
|
"""Call a service.
|
||||||
|
|
||||||
Specify blocking=True to wait until service is executed.
|
Specify blocking=True to wait until service is executed.
|
||||||
|
|
||||||
|
If return_values=True, indicates that the caller can consume return values
|
||||||
|
from the service, if any. Return values are a dict that can be returned by the
|
||||||
|
standard JSON serialization process. Return values can only be used with blocking=True.
|
||||||
|
|
||||||
This method will fire an event to indicate the service has been called.
|
This method will fire an event to indicate the service has been called.
|
||||||
|
|
||||||
Because the service is sent as an event you are not allowed to use
|
Because the service is sent as an event you are not allowed to use
|
||||||
|
@ -1840,6 +1864,9 @@ class ServiceRegistry:
|
||||||
context = context or Context()
|
context = context or Context()
|
||||||
service_data = service_data or {}
|
service_data = service_data or {}
|
||||||
|
|
||||||
|
if return_values and not blocking:
|
||||||
|
raise ValueError("Invalid argument return_values=True when blocking=False")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler = self._services[domain][service]
|
handler = self._services[domain][service]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -1862,7 +1889,9 @@ class ServiceRegistry:
|
||||||
else:
|
else:
|
||||||
processed_data = service_data
|
processed_data = service_data
|
||||||
|
|
||||||
service_call = ServiceCall(domain, service, processed_data, context)
|
service_call = ServiceCall(
|
||||||
|
domain, service, processed_data, context, return_values
|
||||||
|
)
|
||||||
|
|
||||||
self._hass.bus.async_fire(
|
self._hass.bus.async_fire(
|
||||||
EVENT_CALL_SERVICE,
|
EVENT_CALL_SERVICE,
|
||||||
|
@ -1877,13 +1906,20 @@ class ServiceRegistry:
|
||||||
coro = self._execute_service(handler, service_call)
|
coro = self._execute_service(handler, service_call)
|
||||||
if not blocking:
|
if not blocking:
|
||||||
self._run_service_in_background(coro, service_call)
|
self._run_service_in_background(coro, service_call)
|
||||||
return
|
return None
|
||||||
|
|
||||||
await coro
|
response_data = await coro
|
||||||
|
if not return_values:
|
||||||
|
return None
|
||||||
|
if not isinstance(response_data, dict):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"Service response data expected a dictionary, was {type(response_data)}"
|
||||||
|
)
|
||||||
|
return response_data
|
||||||
|
|
||||||
def _run_service_in_background(
|
def _run_service_in_background(
|
||||||
self,
|
self,
|
||||||
coro_or_task: Coroutine[Any, Any, None] | asyncio.Task[None],
|
coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any],
|
||||||
service_call: ServiceCall,
|
service_call: ServiceCall,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run service call in background, catching and logging any exceptions."""
|
"""Run service call in background, catching and logging any exceptions."""
|
||||||
|
@ -1909,18 +1945,21 @@ class ServiceRegistry:
|
||||||
|
|
||||||
async def _execute_service(
|
async def _execute_service(
|
||||||
self, handler: Service, service_call: ServiceCall
|
self, handler: Service, service_call: ServiceCall
|
||||||
) -> None:
|
) -> ServiceResult:
|
||||||
"""Execute a service."""
|
"""Execute a service."""
|
||||||
if handler.job.job_type == HassJobType.Coroutinefunction:
|
if handler.job.job_type == HassJobType.Coroutinefunction:
|
||||||
await cast(Callable[[ServiceCall], Awaitable[None]], handler.job.target)(
|
return await cast(
|
||||||
|
Callable[[ServiceCall], Awaitable[ServiceResult]],
|
||||||
|
handler.job.target,
|
||||||
|
)(service_call)
|
||||||
|
if handler.job.job_type == HassJobType.Callback:
|
||||||
|
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
|
||||||
service_call
|
service_call
|
||||||
)
|
)
|
||||||
elif handler.job.job_type == HassJobType.Callback:
|
return await self._hass.async_add_executor_job(
|
||||||
cast(Callable[[ServiceCall], None], handler.job.target)(service_call)
|
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
|
||||||
else:
|
service_call,
|
||||||
await self._hass.async_add_executor_job(
|
)
|
||||||
cast(Callable[[ServiceCall], None], handler.job.target), service_call
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -33,8 +33,9 @@ from homeassistant.const import (
|
||||||
__version__,
|
__version__,
|
||||||
)
|
)
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.core import HassJob, HomeAssistant, State
|
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
|
||||||
from homeassistant.exceptions import (
|
from homeassistant.exceptions import (
|
||||||
|
HomeAssistantError,
|
||||||
InvalidEntityFormatError,
|
InvalidEntityFormatError,
|
||||||
InvalidStateError,
|
InvalidStateError,
|
||||||
MaxLengthExceeded,
|
MaxLengthExceeded,
|
||||||
|
@ -1082,6 +1083,128 @@ async def test_serviceregistry_callback_service_raise_exception(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
|
||||||
|
"""Test service call for a service that has return values."""
|
||||||
|
|
||||||
|
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||||
|
"""Service handler coroutine."""
|
||||||
|
assert call.return_values
|
||||||
|
return {"test-reply": "test-value1"}
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_handler,
|
||||||
|
)
|
||||||
|
result = await hass.services.async_call(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_data={},
|
||||||
|
blocking=True,
|
||||||
|
return_values=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result == {"test-reply": "test-value1"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
|
||||||
|
"""Test service call for an async service that has return values."""
|
||||||
|
|
||||||
|
async def service_handler(call: ServiceCall) -> ServiceResult:
|
||||||
|
"""Service handler coroutine."""
|
||||||
|
assert call.return_values
|
||||||
|
return {"test-reply": "test-value1"}
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_handler,
|
||||||
|
)
|
||||||
|
result = await hass.services.async_call(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_data={},
|
||||||
|
blocking=True,
|
||||||
|
return_values=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result == {"test-reply": "test-value1"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_services_call_return_values_requires_blocking(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test that non-blocking service calls cannot return values."""
|
||||||
|
async_mock_service(hass, "test_domain", "test_service")
|
||||||
|
with pytest.raises(ValueError, match="when blocking=False"):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_data={},
|
||||||
|
blocking=False,
|
||||||
|
return_values=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("return_value", "expected_error"),
|
||||||
|
[
|
||||||
|
(True, "expected a dictionary"),
|
||||||
|
(False, "expected a dictionary"),
|
||||||
|
(None, "expected a dictionary"),
|
||||||
|
("some-value", "expected a dictionary"),
|
||||||
|
(["some-list"], "expected a dictionary"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_serviceregistry_return_values_invalid(
|
||||||
|
hass: HomeAssistant, return_value: Any, expected_error: str
|
||||||
|
) -> None:
|
||||||
|
"""Test service call return values are not returned when there is no result schema."""
|
||||||
|
|
||||||
|
def service_handler(call: ServiceCall) -> ServiceResult:
|
||||||
|
"""Service handler coroutine."""
|
||||||
|
assert call.return_values
|
||||||
|
return return_value
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_handler,
|
||||||
|
)
|
||||||
|
with pytest.raises(HomeAssistantError, match=expected_error):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_data={},
|
||||||
|
blocking=True,
|
||||||
|
return_values=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
|
||||||
|
"""Test service call data when not asked for return values."""
|
||||||
|
|
||||||
|
def service_handler(call: ServiceCall) -> None:
|
||||||
|
"""Service handler coroutine."""
|
||||||
|
assert not call.return_values
|
||||||
|
return
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_handler,
|
||||||
|
)
|
||||||
|
result = await hass.services.async_call(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
service_data={},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
|
||||||
async def test_config_defaults() -> None:
|
async def test_config_defaults() -> None:
|
||||||
"""Test config defaults."""
|
"""Test config defaults."""
|
||||||
hass = Mock()
|
hass = Mock()
|
||||||
|
|
Loading…
Add table
Reference in a new issue