Improve service response data APIs (#94819)

* Improve service response data APIs

Make the API naming more consistent, and require registration that a
service supports response data so that we can better integrate with
the UI and avoid user confusion with better error messages.

* Improve test coverage

* Add an enum for registering response values

* Assign enum values

* Convert SupportsResponse to StrEnum

* Update service call test docstrings

* Add tiny missing full stop in comment

---------

Co-authored-by: Franck Nijhof <frenck@frenck.nl>
This commit is contained in:
Allen Porter 2023-06-20 06:24:31 -07:00 committed by GitHub
parent 4a8adae146
commit 30e8f806c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 147 additions and 78 deletions

View file

@ -131,7 +131,7 @@ DOMAIN = "homeassistant"
# How long to wait to log tasks that are blocking
BLOCK_LOG_TIMEOUT = 60
ServiceResult = JsonObjectType | None
ServiceResponse = JsonObjectType | None
class ConfigSource(StrEnum):
@ -1655,28 +1655,43 @@ class StateMachine:
)
class SupportsResponse(StrEnum):
"""Service call response configuration."""
NONE = "none"
"""The service does not support responses (the default)."""
OPTIONAL = "optional"
"""The service optionally returns response data when asked by the caller."""
ONLY = "only"
"""The service is read-only and the caller must always ask for response data."""
class Service:
"""Representation of a callable service."""
__slots__ = ["job", "schema", "domain", "service"]
__slots__ = ["job", "schema", "domain", "service", "supports_response"]
def __init__(
self,
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResult] | None],
func: Callable[[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None],
schema: vol.Schema | None,
domain: str,
service: str,
context: Context | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Initialize a service."""
self.job = HassJob(func, f"service {domain}.{service}")
self.schema = schema
self.supports_response = supports_response
class ServiceCall:
"""Representation of a call to a service."""
__slots__ = ["domain", "service", "data", "context", "return_values"]
__slots__ = ["domain", "service", "data", "context", "return_response"]
def __init__(
self,
@ -1684,14 +1699,14 @@ class ServiceCall:
service: str,
data: dict[str, Any] | None = None,
context: Context | None = None,
return_values: bool = False,
return_response: bool = False,
) -> None:
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = ReadOnlyDict(data or {})
self.context = context or Context()
self.return_values = return_values
self.return_response = return_response
def __repr__(self) -> str:
"""Return the representation of the service."""
@ -1738,7 +1753,7 @@ class ServiceRegistry:
service: str,
service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResult] | None,
Coroutine[Any, Any, ServiceResponse] | None,
],
schema: vol.Schema | None = None,
) -> None:
@ -1756,9 +1771,10 @@ class ServiceRegistry:
domain: str,
service: str,
service_func: Callable[
[ServiceCall], Coroutine[Any, Any, ServiceResult] | None
[ServiceCall], Coroutine[Any, Any, ServiceResponse] | None
],
schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Register a service.
@ -1768,7 +1784,9 @@ class ServiceRegistry:
"""
domain = domain.lower()
service = service.lower()
service_obj = Service(service_func, schema, domain, service)
service_obj = Service(
service_func, schema, domain, service, supports_response=supports_response
)
if domain in self._services:
self._services[domain][service] = service_obj
@ -1815,8 +1833,8 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
return_values: bool = False,
) -> ServiceResult:
return_response: bool = False,
) -> ServiceResponse:
"""Call a service.
See description of async_call for details.
@ -1829,7 +1847,7 @@ class ServiceRegistry:
blocking,
context,
target,
return_values,
return_response,
),
self._hass.loop,
).result()
@ -1842,13 +1860,13 @@ class ServiceRegistry:
blocking: bool = False,
context: Context | None = None,
target: dict[str, Any] | None = None,
return_values: bool = False,
) -> ServiceResult:
return_response: bool = False,
) -> ServiceResponse:
"""Call a service.
Specify blocking=True to wait until service is executed.
If return_values=True, indicates that the caller can consume return values
If return_response=True, indicates that the caller can consume return values
from the service, if any. Return values are a dict that can be returned by the
standard JSON serialization process. Return values can only be used with blocking=True.
@ -1864,14 +1882,25 @@ class ServiceRegistry:
context = context or Context()
service_data = service_data or {}
if return_values and not blocking:
raise ValueError("Invalid argument return_values=True when blocking=False")
try:
handler = self._services[domain][service]
except KeyError:
raise ServiceNotFound(domain, service) from None
if return_response:
if not blocking:
raise ValueError(
"Invalid argument return_response=True when blocking=False"
)
if handler.supports_response == SupportsResponse.NONE:
raise ValueError(
"Invalid argument return_response=True when handler does not support responses"
)
elif handler.supports_response == SupportsResponse.ONLY:
raise ValueError(
"Service call requires responses but caller did not ask for responses"
)
if target:
service_data.update(target)
@ -1890,7 +1919,7 @@ class ServiceRegistry:
processed_data = service_data
service_call = ServiceCall(
domain, service, processed_data, context, return_values
domain, service, processed_data, context, return_response
)
self._hass.bus.async_fire(
@ -1909,7 +1938,7 @@ class ServiceRegistry:
return None
response_data = await coro
if not return_values:
if not return_response:
return None
if not isinstance(response_data, dict):
raise HomeAssistantError(
@ -1945,19 +1974,19 @@ class ServiceRegistry:
async def _execute_service(
self, handler: Service, service_call: ServiceCall
) -> ServiceResult:
) -> ServiceResponse:
"""Execute a service."""
if handler.job.job_type == HassJobType.Coroutinefunction:
return await cast(
Callable[[ServiceCall], Awaitable[ServiceResult]],
Callable[[ServiceCall], Awaitable[ServiceResponse]],
handler.job.target,
)(service_call)
if handler.job.job_type == HassJobType.Callback:
return cast(Callable[[ServiceCall], ServiceResult], handler.job.target)(
return cast(Callable[[ServiceCall], ServiceResponse], handler.job.target)(
service_call
)
return await self._hass.async_add_executor_job(
cast(Callable[[ServiceCall], ServiceResult], handler.job.target),
cast(Callable[[ServiceCall], ServiceResponse], handler.job.target),
service_call,
)