From 84220e92ea5bc185a071a01e25d0eadcadc7e113 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 24 Jul 2023 03:12:21 -0400 Subject: [PATCH] Wrap internal ZHA exceptions in `HomeAssistantError`s (#97033) --- .../zha/core/cluster_handlers/__init__.py | 36 ++++++++++++++++--- tests/components/zha/test_cluster_handlers.py | 35 ++++++++++++++++++ tests/components/zha/test_cover.py | 11 +++--- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/zha/core/cluster_handlers/__init__.py b/homeassistant/components/zha/core/cluster_handlers/__init__.py index dcf8f2a525e..6c05ce2fe4f 100644 --- a/homeassistant/components/zha/core/cluster_handlers/__init__.py +++ b/homeassistant/components/zha/core/cluster_handlers/__init__.py @@ -2,10 +2,11 @@ from __future__ import annotations import asyncio +from collections.abc import Awaitable, Callable, Coroutine from enum import Enum -from functools import partialmethod +import functools import logging -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict import zigpy.exceptions import zigpy.util @@ -19,6 +20,7 @@ from zigpy.zcl.foundation import ( from homeassistant.const import ATTR_COMMAND from homeassistant.core import callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.dispatcher import async_dispatcher_send from ..const import ( @@ -45,8 +47,34 @@ if TYPE_CHECKING: from ..endpoint import Endpoint _LOGGER = logging.getLogger(__name__) +RETRYABLE_REQUEST_DECORATOR = zigpy.util.retryable_request(tries=3) -retry_request = zigpy.util.retryable_request(tries=3) + +_P = ParamSpec("_P") +_FuncType = Callable[_P, Awaitable[Any]] +_ReturnFuncType = Callable[_P, Coroutine[Any, Any, Any]] + + +def retry_request(func: _FuncType[_P]) -> _ReturnFuncType[_P]: + """Send a request with retries and wrap expected zigpy exceptions.""" + + @functools.wraps(func) + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: + try: + return await RETRYABLE_REQUEST_DECORATOR(func)(*args, **kwargs) + except asyncio.TimeoutError as exc: + raise HomeAssistantError( + "Failed to send request: device did not respond" + ) from exc + except zigpy.exceptions.ZigbeeException as exc: + message = "Failed to send request" + + if str(exc): + message = f"{message}: {exc}" + + raise HomeAssistantError(message) from exc + + return wrapper class AttrReportConfig(TypedDict, total=True): @@ -471,7 +499,7 @@ class ClusterHandler(LogMixin): rest = rest[ZHA_CLUSTER_HANDLER_READS_PER_REQ:] return result - get_attributes = partialmethod(_get_attributes, False) + get_attributes = functools.partialmethod(_get_attributes, False) def log(self, level, msg, *args, **kwargs): """Log a message.""" diff --git a/tests/components/zha/test_cluster_handlers.py b/tests/components/zha/test_cluster_handlers.py index 1897383b6c4..7e0e8eaab85 100644 --- a/tests/components/zha/test_cluster_handlers.py +++ b/tests/components/zha/test_cluster_handlers.py @@ -22,6 +22,7 @@ from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.endpoint import Endpoint import homeassistant.components.zha.core.registries as registries from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError from .common import get_zha_gateway, make_zcl_header from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE @@ -831,3 +832,37 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None: zha_endpoint.add_all_cluster_handlers() assert "missing_attr" in caplog.text + + +# parametrize side effects: +@pytest.mark.parametrize( + ("side_effect", "expected_error"), + [ + (zigpy.exceptions.ZigbeeException(), "Failed to send request"), + ( + zigpy.exceptions.ZigbeeException("Zigbee exception"), + "Failed to send request: Zigbee exception", + ), + (asyncio.TimeoutError(), "Failed to send request: device did not respond"), + ], +) +async def test_retry_request( + side_effect: Exception | None, expected_error: str | None +) -> None: + """Test the `retry_request` decorator's handling of zigpy-internal exceptions.""" + + async def func(arg1: int, arg2: int) -> int: + assert arg1 == 1 + assert arg2 == 2 + + raise side_effect + + func = mock.AsyncMock(wraps=func) + decorated_func = cluster_handlers.retry_request(func) + + with pytest.raises(HomeAssistantError) as exc: + await decorated_func(1, arg2=2) + + assert func.await_count == 3 + assert isinstance(exc.value, HomeAssistantError) + assert str(exc.value) == expected_error diff --git a/tests/components/zha/test_cover.py b/tests/components/zha/test_cover.py index d1003418487..7c4198bd881 100644 --- a/tests/components/zha/test_cover.py +++ b/tests/components/zha/test_cover.py @@ -26,6 +26,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import CoreState, HomeAssistant, State +from homeassistant.exceptions import HomeAssistantError from .common import ( async_enable_traffic, @@ -236,7 +237,7 @@ async def test_shade( # close from UI command fails with patch("zigpy.zcl.Cluster.request", side_effect=asyncio.TimeoutError): - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(HomeAssistantError): await hass.services.async_call( COVER_DOMAIN, SERVICE_CLOSE_COVER, @@ -261,7 +262,7 @@ async def test_shade( assert ATTR_CURRENT_POSITION not in hass.states.get(entity_id).attributes await send_attributes_report(hass, cluster_level, {0: 0}) with patch("zigpy.zcl.Cluster.request", side_effect=asyncio.TimeoutError): - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(HomeAssistantError): await hass.services.async_call( COVER_DOMAIN, SERVICE_OPEN_COVER, @@ -285,7 +286,7 @@ async def test_shade( # set position UI command fails with patch("zigpy.zcl.Cluster.request", side_effect=asyncio.TimeoutError): - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(HomeAssistantError): await hass.services.async_call( COVER_DOMAIN, SERVICE_SET_COVER_POSITION, @@ -326,7 +327,7 @@ async def test_shade( # test cover stop with patch("zigpy.zcl.Cluster.request", side_effect=asyncio.TimeoutError): - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(HomeAssistantError): await hass.services.async_call( COVER_DOMAIN, SERVICE_STOP_COVER, @@ -395,7 +396,7 @@ async def test_keen_vent( p2 = patch.object(cluster_level, "request", return_value=[4, 0]) with p1, p2: - with pytest.raises(asyncio.TimeoutError): + with pytest.raises(HomeAssistantError): await hass.services.async_call( COVER_DOMAIN, SERVICE_OPEN_COVER,