Add target to service call API (#45898)

* Add target to service call API

* Fix _async_call_service_step

* CONF_SERVICE_ENTITY_ID overrules target

* Move merging up before processing schema

* Restore services.yaml

* Add test
This commit is contained in:
Bram Kragten 2021-02-10 12:42:28 +01:00 committed by GitHub
parent 7d2d98fc3c
commit 4b493c5ab9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 16 deletions

View file

@ -378,7 +378,7 @@ class APIDomainServicesView(HomeAssistantView):
with AsyncTrackStates(hass) as changed_states:
try:
await hass.services.async_call(
domain, service, data, True, self.context(request)
domain, service, data, blocking=True, context=self.context(request)
)
except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex

View file

@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg):
vol.Required("type"): "call_service",
vol.Required("domain"): str,
vol.Required("service"): str,
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
vol.Optional("service_data"): dict,
}
)
@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg):
msg.get("service_data"),
blocking,
context,
target=msg.get("target"),
)
connection.send_message(
messages.result_message(msg["id"], {"context": context})

View file

@ -1358,6 +1358,7 @@ class ServiceRegistry:
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.
@ -1365,7 +1366,9 @@ class ServiceRegistry:
See description of async_call for details.
"""
return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context, limit),
self.async_call(
domain, service, service_data, blocking, context, limit, target
),
self._hass.loop,
).result()
@ -1377,6 +1380,7 @@ class ServiceRegistry:
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.
@ -1404,6 +1408,9 @@ class ServiceRegistry:
except KeyError:
raise ServiceNotFound(domain, service) from None
if target:
service_data.update(target)
if handler.schema:
try:
processed_data = handler.schema(service_data)

View file

@ -433,14 +433,14 @@ class _ScriptRun:
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)
domain, service_name, service_data = service.async_prepare_call_from_config(
params = service.async_prepare_call_from_config(
self._hass, self._action, self._variables
)
running_script = (
domain == "automation"
and service_name == "trigger"
or domain in ("python_script", "script")
params["domain"] == "automation"
and params["service_name"] == "trigger"
or params["domain"] in ("python_script", "script")
)
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
@ -451,9 +451,7 @@ class _ScriptRun:
service_task = self._hass.async_create_task(
self._hass.services.async_call(
domain,
service_name,
service_data,
**params,
blocking=True,
context=self._context,
limit=limit,

View file

@ -14,6 +14,7 @@ from typing import (
Optional,
Set,
Tuple,
TypedDict,
Union,
cast,
)
@ -70,6 +71,15 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
class ServiceParams(TypedDict):
"""Type for service call parameters."""
domain: str
service: str
service_data: Dict[str, Any]
target: Optional[Dict]
@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
@ -136,7 +146,7 @@ async def async_call_from_config(
raise
_LOGGER.error(ex)
else:
await hass.services.async_call(*params, blocking, context)
await hass.services.async_call(**params, blocking=blocking, context=context)
@ha.callback
@ -146,7 +156,7 @@ def async_prepare_call_from_config(
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
) -> Tuple[str, str, Dict[str, Any]]:
) -> ServiceParams:
"""Prepare to call a service based on a config hash."""
if validate_config:
try:
@ -177,10 +187,9 @@ def async_prepare_call_from_config(
domain, service = domain_service.split(".", 1)
service_data = {}
target = config.get(CONF_TARGET)
if CONF_TARGET in config:
service_data.update(config[CONF_TARGET])
service_data = {}
for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
if conf not in config:
@ -192,9 +201,17 @@ def async_prepare_call_from_config(
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
if target:
target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
else:
target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]}
return domain, service, service_data
return {
"domain": domain,
"service": service,
"service_data": service_data,
"target": target,
}
@bind_hass
@ -431,6 +448,7 @@ async def async_get_all_descriptions(
description = descriptions_cache[cache_key] = {
"description": yaml_description.get("description", ""),
"target": yaml_description.get("target"),
"fields": yaml_description.get("fields", {}),
}

View file

@ -52,6 +52,47 @@ async def test_call_service(hass, websocket_client):
assert call.data == {"hello": "world"}
async def test_call_service_target(hass, websocket_client):
"""Test call service command with target."""
calls = []
@callback
def service_call(call):
calls.append(call)
hass.services.async_register("domain_test", "test_service", service_call)
await websocket_client.send_json(
{
"id": 5,
"type": "call_service",
"domain": "domain_test",
"service": "test_service",
"service_data": {"hello": "world"},
"target": {
"entity_id": ["entity.one", "entity.two"],
"device_id": "deviceid",
},
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert len(calls) == 1
call = calls[0]
assert call.domain == "domain_test"
assert call.service == "test_service"
assert call.data == {
"hello": "world",
"entity_id": ["entity.one", "entity.two"],
"device_id": ["deviceid"],
}
async def test_call_service_not_found(hass, websocket_client):
"""Test call service command."""
await websocket_client.send_json(