Add support for services to return data (#94401)

* Add support for service calls with resopnse data.

Update the service calls to allow returning responses with data,
with an initial use case supporting basic service calls usable
within script.

* Revert enttiy platform/component changes

* Remove unnecessary comma diff

* Revert additional unnecessary changes

* Simplify service call

* Simplify and fix typing and revert whitespace

* Clarify typing intent

* Revert more entity service calls

* Revert additional entity service changes

* Set blocking=True for group notify service call

* Revert unnecessary changes

* Reverting more whitespace changes

* Revert more service changes

* Add test coverage for None return case

* Add parameter to service calls indicating return values were requested

* Update tests/test_core.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Add additional service call tests

* Update test comment

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2023-06-16 09:43:35 -07:00 committed by GitHub
parent 12129e9d21
commit 84c66b3cad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 184 additions and 22 deletions

View file

@ -66,7 +66,7 @@ class GroupNotifyPlatform(BaseNotificationService):
payload: dict[str, Any] = {ATTR_MESSAGE: message} payload: dict[str, Any] = {ATTR_MESSAGE: message}
payload.update({key: val for key, val in kwargs.items() if val}) payload.update({key: val for key, val in kwargs.items() if val})
tasks: list[asyncio.Task[bool | None]] = [] tasks: list[asyncio.Task[Any]] = []
for entity in self.entities: for entity in self.entities:
sending_payload = deepcopy(payload.copy()) sending_payload = deepcopy(payload.copy())
if (default_data := entity.get(ATTR_DATA)) is not None: if (default_data := entity.get(ATTR_DATA)) is not None:
@ -74,7 +74,7 @@ class GroupNotifyPlatform(BaseNotificationService):
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
self.hass.services.async_call( self.hass.services.async_call(
DOMAIN, entity[ATTR_SERVICE], sending_payload DOMAIN, entity[ATTR_SERVICE], sending_payload, blocking=True
) )
) )
) )

View file

@ -88,6 +88,7 @@ from .util.async_ import (
run_callback_threadsafe, run_callback_threadsafe,
shutdown_run_callback_threadsafe, shutdown_run_callback_threadsafe,
) )
from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager from .util.timeout import TimeoutManager
from .util.unit_system import ( from .util.unit_system import (
@ -130,6 +131,8 @@ DOMAIN = "homeassistant"
# How long to wait to log tasks that are blocking # How long to wait to log tasks that are blocking
BLOCK_LOG_TIMEOUT = 60 BLOCK_LOG_TIMEOUT = 60
ServiceResult = JsonObjectType | None
class ConfigSource(StrEnum): class ConfigSource(StrEnum):
"""Source of core configuration.""" """Source of core configuration."""
@ -1659,7 +1662,7 @@ class Service:
def __init__( def __init__(
self, self,
func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
schema: vol.Schema | None, schema: vol.Schema | None,
domain: str, domain: str,
service: str, service: str,
@ -1673,7 +1676,7 @@ class Service:
class ServiceCall: class ServiceCall:
"""Representation of a call to a service.""" """Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context"] __slots__ = ["domain", "service", "data", "context", "return_values"]
def __init__( def __init__(
self, self,
@ -1681,12 +1684,14 @@ class ServiceCall:
service: str, service: str,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
context: Context | None = None, context: Context | None = None,
return_values: bool = False,
) -> None: ) -> None:
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain.lower()
self.service = service.lower() self.service = service.lower()
self.data = ReadOnlyDict(data or {}) self.data = ReadOnlyDict(data or {})
self.context = context or Context() self.context = context or Context()
self.return_values = return_values
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return the representation of the service.""" """Return the representation of the service."""
@ -1731,7 +1736,10 @@ class ServiceRegistry:
self, self,
domain: str, domain: str,
service: str, service: str,
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResult] | None,
],
schema: vol.Schema | None = None, schema: vol.Schema | None = None,
) -> None: ) -> None:
"""Register a service. """Register a service.
@ -1747,7 +1755,9 @@ class ServiceRegistry:
self, self,
domain: str, domain: str,
service: str, service: str,
service_func: Callable[[ServiceCall], Coroutine[Any, Any, None] | None], service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
],
schema: vol.Schema | None = None, schema: vol.Schema | None = None,
) -> None: ) -> None:
"""Register a service. """Register a service.
@ -1805,13 +1815,22 @@ class ServiceRegistry:
blocking: bool = False, blocking: bool = False,
context: Context | None = None, context: Context | None = None,
target: dict[str, Any] | None = None, target: dict[str, Any] | None = None,
) -> bool | None: return_values: bool = False,
) -> ServiceResult:
"""Call a service. """Call a service.
See description of async_call for details. See description of async_call for details.
""" """
return asyncio.run_coroutine_threadsafe( return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context, target), self.async_call(
domain,
service,
service_data,
blocking,
context,
target,
return_values,
),
self._hass.loop, self._hass.loop,
).result() ).result()
@ -1823,11 +1842,16 @@ class ServiceRegistry:
blocking: bool = False, blocking: bool = False,
context: Context | None = None, context: Context | None = None,
target: dict[str, Any] | None = None, target: dict[str, Any] | None = None,
) -> None: return_values: bool = False,
) -> ServiceResult:
"""Call a service. """Call a service.
Specify blocking=True to wait until service is executed. Specify blocking=True to wait until service is executed.
If return_values=True, indicates that the caller can consume return values
from the service, if any. Return values are a dict that can be returned by the
standard JSON serialization process. Return values can only be used with blocking=True.
This method will fire an event to indicate the service has been called. This method will fire an event to indicate the service has been called.
Because the service is sent as an event you are not allowed to use Because the service is sent as an event you are not allowed to use
@ -1840,6 +1864,9 @@ class ServiceRegistry:
context = context or Context() context = context or Context()
service_data = service_data or {} service_data = service_data or {}
if return_values and not blocking:
raise ValueError("Invalid argument return_values=True when blocking=False")
try: try:
handler = self._services[domain][service] handler = self._services[domain][service]
except KeyError: except KeyError:
@ -1862,7 +1889,9 @@ class ServiceRegistry:
else: else:
processed_data = service_data processed_data = service_data
service_call = ServiceCall(domain, service, processed_data, context) service_call = ServiceCall(
domain, service, processed_data, context, return_values
)
self._hass.bus.async_fire( self._hass.bus.async_fire(
EVENT_CALL_SERVICE, EVENT_CALL_SERVICE,
@ -1877,13 +1906,20 @@ class ServiceRegistry:
coro = self._execute_service(handler, service_call) coro = self._execute_service(handler, service_call)
if not blocking: if not blocking:
self._run_service_in_background(coro, service_call) self._run_service_in_background(coro, service_call)
return return None
await coro response_data = await coro
if not return_values:
return None
if not isinstance(response_data, dict):
raise HomeAssistantError(
f"Service response data expected a dictionary, was {type(response_data)}"
)
return response_data
def _run_service_in_background( def _run_service_in_background(
self, self,
coro_or_task: Coroutine[Any, Any, None] | asyncio.Task[None], coro_or_task: Coroutine[Any, Any, Any] | asyncio.Task[Any],
service_call: ServiceCall, service_call: ServiceCall,
) -> None: ) -> None:
"""Run service call in background, catching and logging any exceptions.""" """Run service call in background, catching and logging any exceptions."""
@ -1909,18 +1945,21 @@ class ServiceRegistry:
async def _execute_service( async def _execute_service(
self, handler: Service, service_call: ServiceCall self, handler: Service, service_call: ServiceCall
) -> None: ) -> ServiceResult:
"""Execute a service.""" """Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction: if handler.job.job_type == HassJobType.Coroutinefunction:
await cast(Callable[[ServiceCall], Awaitable[None]], handler.job.target)( return await cast(
Callable[[ServiceCall], Awaitable[ServiceResult]],
handler.job.target,
)(service_call)
if handler.job.job_type == HassJobType.Callback:
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
service_call service_call
) )
elif handler.job.job_type == HassJobType.Callback: return await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], None], handler.job.target)(service_call) cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
else: service_call,
await self._hass.async_add_executor_job( )
cast(Callable[[ServiceCall], None], handler.job.target), service_call
)
class Config: class Config:

View file

@ -33,8 +33,9 @@ from homeassistant.const import (
__version__, __version__,
) )
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.core import HassJob, HomeAssistant, State from homeassistant.core import HassJob, HomeAssistant, ServiceCall, ServiceResult, State
from homeassistant.exceptions import ( from homeassistant.exceptions import (
HomeAssistantError,
InvalidEntityFormatError, InvalidEntityFormatError,
InvalidStateError, InvalidStateError,
MaxLengthExceeded, MaxLengthExceeded,
@ -1082,6 +1083,128 @@ async def test_serviceregistry_callback_service_raise_exception(
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_serviceregistry_return_values(hass: HomeAssistant) -> None:
"""Test service call for a service that has return values."""
def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_serviceregistry_async_return_values(hass: HomeAssistant) -> None:
"""Test service call for an async service that has return values."""
async def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return {"test-reply": "test-value1"}
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
assert result == {"test-reply": "test-value1"}
async def test_services_call_return_values_requires_blocking(
hass: HomeAssistant,
) -> None:
"""Test that non-blocking service calls cannot return values."""
async_mock_service(hass, "test_domain", "test_service")
with pytest.raises(ValueError, match="when blocking=False"):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=False,
return_values=True,
)
@pytest.mark.parametrize(
("return_value", "expected_error"),
[
(True, "expected a dictionary"),
(False, "expected a dictionary"),
(None, "expected a dictionary"),
("some-value", "expected a dictionary"),
(["some-list"], "expected a dictionary"),
],
)
async def test_serviceregistry_return_values_invalid(
hass: HomeAssistant, return_value: Any, expected_error: str
) -> None:
"""Test service call return values are not returned when there is no result schema."""
def service_handler(call: ServiceCall) -> ServiceResult:
"""Service handler coroutine."""
assert call.return_values
return return_value
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
with pytest.raises(HomeAssistantError, match=expected_error):
await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
return_values=True,
)
await hass.async_block_till_done()
async def test_serviceregistry_no_return_values(hass: HomeAssistant) -> None:
"""Test service call data when not asked for return values."""
def service_handler(call: ServiceCall) -> None:
"""Service handler coroutine."""
assert not call.return_values
return
hass.services.async_register(
"test_domain",
"test_service",
service_handler,
)
result = await hass.services.async_call(
"test_domain",
"test_service",
service_data={},
blocking=True,
)
await hass.async_block_till_done()
assert not result
async def test_config_defaults() -> None: async def test_config_defaults() -> None:
"""Test config defaults.""" """Test config defaults."""
hass = Mock() hass = Mock()