Add target to service call API (#45898)
* Add target to service call API * Fix _async_call_service_step * CONF_SERVICE_ENTITY_ID overrules target * Move merging up before processing schema * Restore services.yaml * Add test
This commit is contained in:
parent
7d2d98fc3c
commit
4b493c5ab9
6 changed files with 82 additions and 16 deletions
|
@ -378,7 +378,7 @@ class APIDomainServicesView(HomeAssistantView):
|
|||
with AsyncTrackStates(hass) as changed_states:
|
||||
try:
|
||||
await hass.services.async_call(
|
||||
domain, service, data, True, self.context(request)
|
||||
domain, service, data, blocking=True, context=self.context(request)
|
||||
)
|
||||
except (vol.Invalid, ServiceNotFound) as ex:
|
||||
raise HTTPBadRequest() from ex
|
||||
|
|
|
@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||
vol.Required("type"): "call_service",
|
||||
vol.Required("domain"): str,
|
||||
vol.Required("service"): str,
|
||||
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
|
||||
vol.Optional("service_data"): dict,
|
||||
}
|
||||
)
|
||||
|
@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg):
|
|||
msg.get("service_data"),
|
||||
blocking,
|
||||
context,
|
||||
target=msg.get("target"),
|
||||
)
|
||||
connection.send_message(
|
||||
messages.result_message(msg["id"], {"context": context})
|
||||
|
|
|
@ -1358,6 +1358,7 @@ class ServiceRegistry:
|
|||
blocking: bool = False,
|
||||
context: Optional[Context] = None,
|
||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||
target: Optional[Dict] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
@ -1365,7 +1366,9 @@ class ServiceRegistry:
|
|||
See description of async_call for details.
|
||||
"""
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.async_call(domain, service, service_data, blocking, context, limit),
|
||||
self.async_call(
|
||||
domain, service, service_data, blocking, context, limit, target
|
||||
),
|
||||
self._hass.loop,
|
||||
).result()
|
||||
|
||||
|
@ -1377,6 +1380,7 @@ class ServiceRegistry:
|
|||
blocking: bool = False,
|
||||
context: Optional[Context] = None,
|
||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||
target: Optional[Dict] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
@ -1404,6 +1408,9 @@ class ServiceRegistry:
|
|||
except KeyError:
|
||||
raise ServiceNotFound(domain, service) from None
|
||||
|
||||
if target:
|
||||
service_data.update(target)
|
||||
|
||||
if handler.schema:
|
||||
try:
|
||||
processed_data = handler.schema(service_data)
|
||||
|
|
|
@ -433,14 +433,14 @@ class _ScriptRun:
|
|||
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
|
||||
domain, service_name, service_data = service.async_prepare_call_from_config(
|
||||
params = service.async_prepare_call_from_config(
|
||||
self._hass, self._action, self._variables
|
||||
)
|
||||
|
||||
running_script = (
|
||||
domain == "automation"
|
||||
and service_name == "trigger"
|
||||
or domain in ("python_script", "script")
|
||||
params["domain"] == "automation"
|
||||
and params["service_name"] == "trigger"
|
||||
or params["domain"] in ("python_script", "script")
|
||||
)
|
||||
# If this might start a script then disable the call timeout.
|
||||
# Otherwise use the normal service call limit.
|
||||
|
@ -451,9 +451,7 @@ class _ScriptRun:
|
|||
|
||||
service_task = self._hass.async_create_task(
|
||||
self._hass.services.async_call(
|
||||
domain,
|
||||
service_name,
|
||||
service_data,
|
||||
**params,
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
limit=limit,
|
||||
|
|
|
@ -14,6 +14,7 @@ from typing import (
|
|||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
@ -70,6 +71,15 @@ _LOGGER = logging.getLogger(__name__)
|
|||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||
|
||||
|
||||
class ServiceParams(TypedDict):
|
||||
"""Type for service call parameters."""
|
||||
|
||||
domain: str
|
||||
service: str
|
||||
service_data: Dict[str, Any]
|
||||
target: Optional[Dict]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SelectedEntities:
|
||||
"""Class to hold the selected entities."""
|
||||
|
@ -136,7 +146,7 @@ async def async_call_from_config(
|
|||
raise
|
||||
_LOGGER.error(ex)
|
||||
else:
|
||||
await hass.services.async_call(*params, blocking, context)
|
||||
await hass.services.async_call(**params, blocking=blocking, context=context)
|
||||
|
||||
|
||||
@ha.callback
|
||||
|
@ -146,7 +156,7 @@ def async_prepare_call_from_config(
|
|||
config: ConfigType,
|
||||
variables: TemplateVarsType = None,
|
||||
validate_config: bool = False,
|
||||
) -> Tuple[str, str, Dict[str, Any]]:
|
||||
) -> ServiceParams:
|
||||
"""Prepare to call a service based on a config hash."""
|
||||
if validate_config:
|
||||
try:
|
||||
|
@ -177,10 +187,9 @@ def async_prepare_call_from_config(
|
|||
|
||||
domain, service = domain_service.split(".", 1)
|
||||
|
||||
service_data = {}
|
||||
target = config.get(CONF_TARGET)
|
||||
|
||||
if CONF_TARGET in config:
|
||||
service_data.update(config[CONF_TARGET])
|
||||
service_data = {}
|
||||
|
||||
for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
|
||||
if conf not in config:
|
||||
|
@ -192,9 +201,17 @@ def async_prepare_call_from_config(
|
|||
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
|
||||
|
||||
if CONF_SERVICE_ENTITY_ID in config:
|
||||
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||
if target:
|
||||
target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||
else:
|
||||
target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]}
|
||||
|
||||
return domain, service, service_data
|
||||
return {
|
||||
"domain": domain,
|
||||
"service": service,
|
||||
"service_data": service_data,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -431,6 +448,7 @@ async def async_get_all_descriptions(
|
|||
|
||||
description = descriptions_cache[cache_key] = {
|
||||
"description": yaml_description.get("description", ""),
|
||||
"target": yaml_description.get("target"),
|
||||
"fields": yaml_description.get("fields", {}),
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,47 @@ async def test_call_service(hass, websocket_client):
|
|||
assert call.data == {"hello": "world"}
|
||||
|
||||
|
||||
async def test_call_service_target(hass, websocket_client):
|
||||
"""Test call service command with target."""
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def service_call(call):
|
||||
calls.append(call)
|
||||
|
||||
hass.services.async_register("domain_test", "test_service", service_call)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "call_service",
|
||||
"domain": "domain_test",
|
||||
"service": "test_service",
|
||||
"service_data": {"hello": "world"},
|
||||
"target": {
|
||||
"entity_id": ["entity.one", "entity.two"],
|
||||
"device_id": "deviceid",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 5
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
|
||||
assert call.domain == "domain_test"
|
||||
assert call.service == "test_service"
|
||||
assert call.data == {
|
||||
"hello": "world",
|
||||
"entity_id": ["entity.one", "entity.two"],
|
||||
"device_id": ["deviceid"],
|
||||
}
|
||||
|
||||
|
||||
async def test_call_service_not_found(hass, websocket_client):
|
||||
"""Test call service command."""
|
||||
await websocket_client.send_json(
|
||||
|
|
Loading…
Add table
Reference in a new issue