Shield service call from cancellation on REST API connection loss (#102657)

* Shield service call from cancellation on connection loss

* add test for timeout

* Apply suggestions from code review

* Apply suggestions from code review

* fix merge

* Apply suggestions from code review
This commit is contained in:
Denis Shulyaka 2023-11-02 14:58:26 +03:00 committed by GitHub
parent 4a4d2ad743
commit d18b2d8748
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 4 deletions

View file

@ -1,6 +1,6 @@
"""Rest API for Home Assistant.""" """Rest API for Home Assistant."""
import asyncio import asyncio
from asyncio import timeout from asyncio import shield, timeout
from collections.abc import Collection from collections.abc import Collection
from functools import lru_cache from functools import lru_cache
from http import HTTPStatus from http import HTTPStatus
@ -62,6 +62,7 @@ ATTR_VERSION = "version"
DOMAIN = "api" DOMAIN = "api"
STREAM_PING_PAYLOAD = "ping" STREAM_PING_PAYLOAD = "ping"
STREAM_PING_INTERVAL = 50 # seconds STREAM_PING_INTERVAL = 50 # seconds
SERVICE_WAIT_TIMEOUT = 10
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@ -388,11 +389,17 @@ class APIDomainServicesView(HomeAssistantView):
) )
try: try:
await hass.services.async_call( async with timeout(SERVICE_WAIT_TIMEOUT):
domain, service, data, blocking=True, context=context # shield the service call from cancellation on connection drop
) await shield(
hass.services.async_call(
domain, service, data, blocking=True, context=context
)
)
except (vol.Invalid, ServiceNotFound) as ex: except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex raise HTTPBadRequest() from ex
except TimeoutError:
pass
finally: finally:
cancel_listen() cancel_listen()

View file

@ -352,6 +352,31 @@ async def test_api_call_service_with_data(
assert state["attributes"] == {"data": 1} assert state["attributes"] == {"data": 1}
async def test_api_call_service_timeout(
hass: HomeAssistant, mock_api_client: TestClient
) -> None:
"""Test if the API does not fail on long running services."""
test_value = []
fut = hass.loop.create_future()
async def listener(service_call):
"""Wait and return after mock_api_client.post finishes."""
value = await fut
test_value.append(value)
hass.services.async_register("test_domain", "test_service", listener)
with patch("homeassistant.components.api.SERVICE_WAIT_TIMEOUT", 0):
await mock_api_client.post("/api/services/test_domain/test_service")
assert len(test_value) == 0
fut.set_result(1)
await hass.async_block_till_done()
assert len(test_value) == 1
assert test_value[0] == 1
async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None: async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None:
"""Test the template API.""" """Test the template API."""
hass.states.async_set("sensor.temperature", 10) hass.states.async_set("sensor.temperature", 10)