Shield service call from cancellation on REST API connection loss (#102657)
* Shield service call from cancellation on connection loss * add test for timeout * Apply suggestions from code review * Apply suggestions from code review * fix merge * Apply suggestions from code review
This commit is contained in:
parent
4a4d2ad743
commit
d18b2d8748
2 changed files with 36 additions and 4 deletions
|
@ -1,6 +1,6 @@
|
||||||
"""Rest API for Home Assistant."""
|
"""Rest API for Home Assistant."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import timeout
|
from asyncio import shield, timeout
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
@ -62,6 +62,7 @@ ATTR_VERSION = "version"
|
||||||
DOMAIN = "api"
|
DOMAIN = "api"
|
||||||
STREAM_PING_PAYLOAD = "ping"
|
STREAM_PING_PAYLOAD = "ping"
|
||||||
STREAM_PING_INTERVAL = 50 # seconds
|
STREAM_PING_INTERVAL = 50 # seconds
|
||||||
|
SERVICE_WAIT_TIMEOUT = 10
|
||||||
|
|
||||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||||
|
|
||||||
|
@ -388,11 +389,17 @@ class APIDomainServicesView(HomeAssistantView):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await hass.services.async_call(
|
async with timeout(SERVICE_WAIT_TIMEOUT):
|
||||||
|
# shield the service call from cancellation on connection drop
|
||||||
|
await shield(
|
||||||
|
hass.services.async_call(
|
||||||
domain, service, data, blocking=True, context=context
|
domain, service, data, blocking=True, context=context
|
||||||
)
|
)
|
||||||
|
)
|
||||||
except (vol.Invalid, ServiceNotFound) as ex:
|
except (vol.Invalid, ServiceNotFound) as ex:
|
||||||
raise HTTPBadRequest() from ex
|
raise HTTPBadRequest() from ex
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
cancel_listen()
|
cancel_listen()
|
||||||
|
|
||||||
|
|
|
@ -352,6 +352,31 @@ async def test_api_call_service_with_data(
|
||||||
assert state["attributes"] == {"data": 1}
|
assert state["attributes"] == {"data": 1}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_api_call_service_timeout(
|
||||||
|
hass: HomeAssistant, mock_api_client: TestClient
|
||||||
|
) -> None:
|
||||||
|
"""Test if the API does not fail on long running services."""
|
||||||
|
test_value = []
|
||||||
|
|
||||||
|
fut = hass.loop.create_future()
|
||||||
|
|
||||||
|
async def listener(service_call):
|
||||||
|
"""Wait and return after mock_api_client.post finishes."""
|
||||||
|
value = await fut
|
||||||
|
test_value.append(value)
|
||||||
|
|
||||||
|
hass.services.async_register("test_domain", "test_service", listener)
|
||||||
|
|
||||||
|
with patch("homeassistant.components.api.SERVICE_WAIT_TIMEOUT", 0):
|
||||||
|
await mock_api_client.post("/api/services/test_domain/test_service")
|
||||||
|
assert len(test_value) == 0
|
||||||
|
fut.set_result(1)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(test_value) == 1
|
||||||
|
assert test_value[0] == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None:
|
async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None:
|
||||||
"""Test the template API."""
|
"""Test the template API."""
|
||||||
hass.states.async_set("sensor.temperature", 10)
|
hass.states.async_set("sensor.temperature", 10)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue