From 618b666126a885bfb0ae44571bec0c44ff5bdada Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Fri, 10 Nov 2023 15:44:43 -0500 Subject: [PATCH] Add support for responses to `call_service` WS cmd (#98610) * Add support for responses to call_service WS cmd * Revert ServiceNotFound removal and add a parameter for return_response * fix type * fix tests * remove exception handling that was added * Revert unnecessary modifications * Use kwargs --- .../components/websocket_api/commands.py | 29 ++++--- tests/common.py | 7 +- .../components/websocket_api/test_commands.py | 79 ++++++++++++++++++- 3 files changed, 102 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 471bbc4745a..18688914e8b 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -18,7 +18,14 @@ from homeassistant.const import ( MATCH_ALL, SIGNAL_BOOTSTRAP_INTEGRATIONS, ) -from homeassistant.core import Context, Event, HomeAssistant, State, callback +from homeassistant.core import ( + Context, + Event, + HomeAssistant, + ServiceResponse, + State, + callback, +) from homeassistant.exceptions import ( HomeAssistantError, ServiceNotFound, @@ -213,6 +220,7 @@ def handle_unsubscribe_events( vol.Required("service"): str, vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS, vol.Optional("service_data"): dict, + vol.Optional("return_response", default=False): bool, } ) @decorators.async_response @@ -220,7 +228,6 @@ async def handle_call_service( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle call service command.""" - blocking = True # We do not support templates. target = msg.get("target") if template.is_complex(target): @@ -228,15 +235,19 @@ async def handle_call_service( try: context = connection.context(msg) - await hass.services.async_call( - msg["domain"], - msg["service"], - msg.get("service_data"), - blocking, - context, + response = await hass.services.async_call( + domain=msg["domain"], + service=msg["service"], + service_data=msg.get("service_data"), + blocking=True, + context=context, target=target, + return_response=msg["return_response"], ) - connection.send_result(msg["id"], {"context": context}) + result: dict[str, Context | ServiceResponse] = {"context": context} + if msg["return_response"]: + result["response"] = response + connection.send_result(msg["id"], result) except ServiceNotFound as err: if err.domain == msg["domain"] and err.service == msg["service"]: connection.send_error( diff --git a/tests/common.py b/tests/common.py index cd522aa3320..1737eae21e6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -307,8 +307,11 @@ def async_mock_service( calls.append(call) return response - if supports_response is None and response is not None: - supports_response = SupportsResponse.OPTIONAL + if supports_response is None: + if response is not None: + supports_response = SupportsResponse.OPTIONAL + else: + supports_response = SupportsResponse.NONE hass.services.async_register( domain, diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 34424545666..a9551310c2a 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -18,7 +18,7 @@ from homeassistant.components.websocket_api.auth import ( ) from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS -from homeassistant.core import Context, HomeAssistant, State, callback +from homeassistant.core import Context, HomeAssistant, State, SupportsResponse, callback from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers import device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_send @@ -183,14 +183,76 @@ async def test_call_service( assert call.context.as_dict() == msg["result"]["context"] +async def test_return_response_error(hass: HomeAssistant, websocket_client) -> None: + """Test return_response=True errors when service has no response.""" + hass.services.async_register( + "domain_test", "test_service_with_no_response", lambda x: None + ) + await websocket_client.send_json( + { + "id": 8, + "type": "call_service", + "domain": "domain_test", + "service": "test_service_with_no_response", + "service_data": {"hello": "world"}, + "return_response": True, + }, + ) + msg = await websocket_client.receive_json() + + assert msg["id"] == 8 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == "unknown_error" + + @pytest.mark.parametrize("command", ("call_service", "call_service_action")) async def test_call_service_blocking( hass: HomeAssistant, websocket_client: MockHAClientWebSocket, command ) -> None: """Test call service commands block, except for homeassistant restart / stop.""" + async_mock_service( + hass, + "domain_test", + "test_service", + response={"hello": "world"}, + supports_response=SupportsResponse.OPTIONAL, + ) with patch( "homeassistant.core.ServiceRegistry.async_call", autospec=True ) as mock_call: + mock_call.return_value = {"foo": "bar"} + await websocket_client.send_json( + { + "id": 4, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {"hello": "world"}, + "return_response": True, + }, + ) + msg = await websocket_client.receive_json() + + assert msg["id"] == 4 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"]["response"] == {"foo": "bar"} + mock_call.assert_called_once_with( + ANY, + "domain_test", + "test_service", + {"hello": "world"}, + blocking=True, + context=ANY, + target=ANY, + return_response=True, + ) + + with patch( + "homeassistant.core.ServiceRegistry.async_call", autospec=True + ) as mock_call: + mock_call.return_value = None await websocket_client.send_json( { "id": 5, @@ -213,11 +275,14 @@ async def test_call_service_blocking( blocking=True, context=ANY, target=ANY, + return_response=False, ) + async_mock_service(hass, "homeassistant", "test_service") with patch( "homeassistant.core.ServiceRegistry.async_call", autospec=True ) as mock_call: + mock_call.return_value = None await websocket_client.send_json( { "id": 6, @@ -239,11 +304,14 @@ async def test_call_service_blocking( blocking=True, context=ANY, target=ANY, + return_response=False, ) + async_mock_service(hass, "homeassistant", "restart") with patch( "homeassistant.core.ServiceRegistry.async_call", autospec=True ) as mock_call: + mock_call.return_value = None await websocket_client.send_json( { "id": 7, @@ -258,7 +326,14 @@ async def test_call_service_blocking( assert msg["type"] == const.TYPE_RESULT assert msg["success"] mock_call.assert_called_once_with( - ANY, "homeassistant", "restart", ANY, blocking=True, context=ANY, target=ANY + ANY, + "homeassistant", + "restart", + ANY, + blocking=True, + context=ANY, + target=ANY, + return_response=False, )